tile_elementwise.hpp Source File

tile_elementwise.hpp Source File

Composable Kernel: tile_elementwise.hpp Source File
tile_elementwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
15
16namespace ck_tile {
17
// Element-wise in/out operation over distributed tensors: for each index i of the
// per-thread buffer, calls inout_element_func with element i of every tensor
// (obtained via get_thread_buffer().at(i)), so the functor may read some arguments
// and write others in place. Returns nothing.
// This overload is enabled only when no argument is a null_tensor (const ignored);
// the no-op overload further down this file handles null_tensor arguments.
18// TODO: support tensors with different distribution
19template <typename InOutElementFunc,
20 typename... InOutDstrTensors,
21 typename = std::enable_if_t<std::conjunction_v<
22 std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
23CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
24 InOutDstrTensors&... inout_dstr_tensors)
25{
26 // TODO: make sure all distributed tensors have same lengths and distribution
27 // static_assert(xxx);
28
 // The buffer length is taken from the first tensor only; the other tensors are
 // assumed (not checked) to have the same length.
29 constexpr index_t thread_buffer_size =
30 __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();
31
 // NOTE(review): source line 32 is missing from this rendered listing. The lambda
 // below is an argument to a call opened on that line, presumably a static_for over
 // [0, thread_buffer_size). Confirm against the header.
33 [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
34}
35
// Element-wise "in" operation: builds and returns a new distributed tensor whose
// buffer element i is in_element_func(b0[i], b1[i], ...), where bN are the per-thread
// buffers of the input tensors. Inputs are taken by const reference and not modified.
// The output tensor uses the first input's tile distribution.
// Enabled only when no argument is a null_tensor; the no-op overload further down
// this file returns a null_tensor in that case.
36template <typename InElementFunc,
37 typename... InTensor,
38 typename = std::enable_if_t<
39 std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
40CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
41 const InTensor&... in_dstr_tensors)
42{
 // Output element type is the return type of in_element_func when called on
 // value-initialised input DataType values, so each InTensor::DataType must be
 // default-constructible.
 // NOTE(review): if the functor returns a reference, OutDataType becomes a reference
 // type; wrapping it in remove_cvref_t would be more robust.
43 using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
44
45 // TODO: make sure all distributed tensors have same lengths and distribution
46 // static_assert(xxx);
47 constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();
48
49 constexpr index_t thread_buffer_size =
50 __type_pack_element<0, InTensor...>::get_thread_buffer_size();
51
52 auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
53
 // NOTE(review): source line 54 (the opening of the per-element loop, presumably a
 // static_for over [0, thread_buffer_size) taking a lambda of index i) is missing
 // from this rendered listing. Confirm against the header.
55 out_dstr_tensor.get_thread_buffer()(i) =
56 in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
57 });
58
59 return out_dstr_tensor;
60}
61
70template <typename InElementFunc, typename Tuple, size_t... I>
71CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
72 const Tuple& t,
73 std::index_sequence<I...>)
74{
75 return tile_elementwise_inout(in_element_func, t[number<I>{}]...);
76}
77
86template <typename InElementFunc, typename Tuple>
87CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
88 const Tuple& t)
89{
90 static constexpr auto size = Tuple::size();
91 return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
92}
93
// Runtime-value variant of set_tile: sets the elements of dstr_tensor from value
// (compare the compile-time index_t v variant below).
// NOTE(review): source lines 97 and 99 are missing from this rendered listing. From
// what survives, a per-element lambda capturing value is passed, together with
// dstr_tensor, to an element-wise helper (presumably tile_elementwise_inout), with
// line 99 presumably assigning a converted copy of value to x. Confirm against the
// header.
94template <typename DstrTensors, typename T>
95CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
96{
98 [&value](auto& x) {
100 },
101 dstr_tensor);
102}
103
// NOTE(review): the signature line (source line 105) is missing from this rendered
// listing. The body is empty, so this is a no-op overload, presumably the set_tile
// overload for null_tensor. Confirm against the header.
104template <typename T>
106{
107}
108
// Compile-time-value variant of set_tile: fills dstr_tensor with v. There is a fast
// dword-granular path for v == 0; the general per-element statement (line 166) is
// missing from this listing.
// NOTE(review): the signature lines (source lines 112-113) are missing from this
// rendered listing. The body uses a parameter named dstr_tensor, and the template
// parameters suggest number<v> plus an optional bool_constant<skip_subdword_opt>.
// Confirm against the header.
109// TODO: prefer per-dword (32-bit) stores when setting a tensor, in case the compiler
110// does not handle sub-dword tensors well...
111template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
114{
115 using elem_type = typename DstrTensors::DataType;
116 constexpr index_t elem_size = sizeof(elem_type);
117
118 constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
119
120 // # bytes per write = 4
 // Fast path: only when v == 0, the per-thread buffer is a whole number of dwords,
 // and skip_subdword_opt is false. v == 0 presumably works because all-zero bits
 // encode zero for the element types in use.
121 if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
122 {
123#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
 // Workaround path: zero the buffer one dword's worth of elements at a time,
 // using compile-time (static_for) indices.
124 auto& buffer = dstr_tensor.get_thread_buffer();
125
126 static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
127 if constexpr(elem_size == 1)
128 {
129 // # elements per write = 4
130 constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
131
132 buffer[i_write * 4 + 0] = values.x;
133 buffer[i_write * 4 + 1] = values.y;
134 buffer[i_write * 4 + 2] = values.z;
135 buffer[i_write * 4 + 3] = values.w;
136 }
137 else if constexpr(elem_size == 2)
138 {
139 // # elements per write = 2
140 constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
141
142 buffer[i_write * 2 + 0] = values.x;
143 buffer[i_write * 2 + 1] = values.y;
144 }
145 else if constexpr(elem_size == 4)
146 {
147 // # elements per write = 1
148 constexpr elem_type value = 0;
149
150 buffer[i_write] = value;
151 }
152 else
153 {
 // NOTE(review): with the workaround enabled, element sizes other than 1, 2
 // or 4 bytes (e.g. 8-byte types) fail here, even though the #else path
 // below handles any size.
 // Also, static_assert(false) in a discarded if-constexpr branch is only
 // accepted by compilers that implement P2593 (C++23, applied as a DR).
 // Older compilers may reject it even when this branch is never taken.
154 static_assert(false, "type not supported");
155 }
156 });
157#else
 // View the buffer as an array of 32-bit words and store v into each one.
 // Writing v per word is only a correct element fill because v == 0 on this path.
 // NOTE(review): this reinterpret_cast type-puns the thread buffer.
158 using dvec_t = array<index_t, tensor_bytes / 4>;
159 auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
160 for(auto i = 0; i < tensor.size(); i++)
161 tensor.get(i) = v;
162#endif
163 }
164 else
165 {
 // NOTE(review): source line 166 is missing from this rendered listing. Together
 // with the `dstr_tensor);` tail it is presumably an element-wise call that assigns
 // v (converted to elem_type) to each element. Confirm against the header.
167 dstr_tensor);
168 }
169}
170
// NOTE(review): source lines 172-174 are missing from this rendered listing; only the
// template header of this definition survives. It sits between the compile-time-value
// set_tile and clear_tile. Confirm its signature and body against the header.
171template <index_t v>
175
176template <typename DstrTensors>
177CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
178{
179 set_tile(dstr_tensor, 0);
180}
181
182namespace impl {
// Converts a distributed tensor to OutDataType four elements at a time, using the
// packed __builtin_amdgcn_cvt_pk_fp8_f32 builtin. Each call converts two f32 values
// into one 16-bit half of a 32-bit word, and two calls fill the whole word.
// This path is only compiled on __gfx94__/__gfx12__ targets; other targets take the
// fallback branch at the end. The input buffer size must be a multiple of 4 (see the
// static_assert below).
// NOTE(review): the fp8 builtin is used regardless of OutDataType. cast_tile's first
// branch admits both fp8_t and bf8_t destinations from float. If that branch
// dispatches here (its call line is missing from this listing), a bf8_t destination
// receives fp8-encoded bytes. A bf8 destination should presumably use
// __builtin_amdgcn_cvt_pk_bf8_f32. Confirm the dispatch and fix.
183// TODO: this is ugly
184template <typename OutDataType, typename InTensor>
185CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
186{
187#if defined(__gfx94__) || defined(__gfx12__)
188 // This API is designed to use the packed (_pk_) series of conversion builtins
189 constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
190
191 constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
192 static_assert(thread_buffer_size % 4 == 0);
193 constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;
194
195 auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
196#pragma clang diagnostic push
197#pragma clang diagnostic ignored "-Wuninitialized"
198 // __builtin_amdgcn_cvt_pk_fp8_f32() requires the old value, and it
199 // will generate a v_mov_b32 vxxx [old] before cvt, which results in unwanted ISA,
200 // so we prepare an uninitialized variable on purpose and turn off the warning
201 int dummy_old;
 // NOTE(review): reading an uninitialised int is formally undefined behaviour. Its
 // only contribution to the result is word 1 of x, which the second call below
 // overwrites (that call writes word 1 and keeps word 0 of x).
 // NOTE(review): source line 202 is missing from this rendered listing. It
 // presumably opens a static_for over [0, thread_buffer_size_pk) with compile-time
 // index i, closed by the `});` below. Confirm against the header.
203 uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
204 in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
205 in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
206 dummy_old,
207 false); // false -> WORD0
208
209 uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
210 in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
211 in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
212 x,
213 true); // true -> WORD1
214
 // y now holds four converted 8-bit values. Reinterpret them as four OutDataType
 // elements and store them as packed group i of the output buffer.
215 using vec_t = array<OutDataType, 4>;
216
217 vec_t d = bit_cast<vec_t>(y);
218 out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
219 });
220#pragma clang diagnostic pop
221
222 return out_dstr_tensor;
223#else
224 // fallback
 // NOTE(review): source line 225 is missing from this rendered listing. With the
 // `in_dstr_tensors);` tail it is presumably a generic element-wise conversion
 // (e.g. tile_elementwise_in with type_convert). Confirm against the header.
226 in_dstr_tensors);
227#endif
228}
229
// Converts a distributed tensor to OutDataType two elements at a time, using the
// packed __builtin_amdgcn_cvt_pkrtz builtin (two f32 values in, a two-element packed
// result o.x / o.y out). This path is only compiled on gfx908/gfx90a/gfx942; other
// targets take the fallback branch. The input buffer size must be even (see the
// static_assert below).
// NOTE(review): the target list here names __gfx942__ only, which is narrower than
// the __gfx94__ family macro used by cast_tile_pk_fp8_fp32. Confirm this is intended.
230template <typename OutDataType, typename InTensor>
231CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
232{
233#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
234 // This API is designed to use the packed (_pk_) series of conversion builtins
235 constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
236
237 constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
238 static_assert(thread_buffer_size % 2 == 0);
239 constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
240
241 auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
242
243 // TODO: this is a round-toward-zero (rtz) conversion; be careful, results may not match round-to-nearest
 // NOTE(review): unlike cast_tile_pk_fp8_fp32, this loop indexes the thread buffers
 // with a runtime index rather than static_for/number<> indices. It presumably
 // relies on full unrolling to keep the indices constant; confirm no scratch is
 // generated.
244 for(index_t i = 0; i < thread_buffer_size_pk; i++)
245 {
246 auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
247 in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
248
249 out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
250 out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
251 }
252
253 return out_dstr_tensor;
254#else
255 // fallback
 // NOTE(review): source line 256 is missing from this rendered listing. With the
 // `in_dstr_tensors);` tail it is presumably a generic element-wise conversion.
 // Confirm against the header.
257 in_dstr_tensors);
258#endif
259}
260
// Sub-dword-aware tile cast, compiled only when CK_TILE_USE_SUBDWORD_TILE_CAST is set.
// Converts in_dstr_tensors to OutDataType in groups ("bulks") of bulk_size elements.
// Each bulk of outputs is assembled in a local union and then stored with a single
// set_as, followed by an element-by-element tail for the remainder.
261#if CK_TILE_USE_SUBDWORD_TILE_CAST
262// This function assumes that the src or dst data type (or both) is smaller than 1 dword (4 bytes).
263// Sub-dword values are packed into 1 dword to avoid the compiler's default sub-dword handling (which is buggy).
264template <typename OutDataType, typename InTensor>
265CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
266{
267 constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
268
269 auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
270
 // NOTE(review): source line 271 is missing from this rendered listing. It
 // presumably defines i_type (used below) from InTensor::DataType. Confirm against
 // the header.
272 using o_type = remove_cvref_t<OutDataType>;
273 constexpr index_t i_elem_bytes = sizeof(i_type);
274 constexpr index_t o_elem_bytes = sizeof(o_type);
275 static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);
276
 // If the input element is at least as wide as the output element, a bulk is
 // 4 / o_elem_bytes outputs (one dword when o_elem_bytes divides 4), carried as a
 // 4-byte float. Otherwise a bulk is 4 / i_elem_bytes inputs, stored as an array of
 // bulk_size outputs.
277 constexpr index_t bulk_size =
278 (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
279 static_assert(bulk_size != 0);
280
281 using o_bulk_type =
282 std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;
283
284 constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
285
286 constexpr index_t iters = thread_buffer_size / bulk_size;
287 constexpr index_t rems = thread_buffer_size % bulk_size;
288
289 // cast the sequence per-bulk
290 static_for<0, iters, 1>{}([&](auto i) {
 // NOTE(review): writing data[] and then reading bulk is union type punning. ISO
 // C++ does not sanction it, but Clang supports it.
291 union bulk_wrapper
292 {
293 o_bulk_type bulk{};
294 o_type data[bulk_size];
295 } o_bulk;
296
 // NOTE(review): the two TODOs around here look inconsistent with the code. The
 // active implementation is the static_for loop below, while the commented-out
 // code is a hand-unrolled 2:1 variant. Confirm which variant each TODO refers to.
297 // TODO: should use below function, but somehow will result in spill (same as c-forloop)
298 static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
299 o_bulk.data[ib.value] = static_cast<o_type>(
300 in_dstr_tensors.get_thread_buffer()
301 .template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
302 });
303
304 // TODO: fixme, should use above!
305 // static_assert(sizeof(i_type) / sizeof(o_type) == 2);
306 // o_bulk.data[0] = static_cast<o_type>(
307 // in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
308 // o_bulk.data[1] = static_cast<o_type>(
309 // in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
310
311 out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
312 });
313
 // Tail: convert the rems leftover elements one at a time.
314 static_for<0, rems, 1>{}([&](auto r) {
315 // TODO: introducing local scratch pad?
 // NOTE(review): source line 316 is missing from this rendered listing. It
 // presumably defines idx as the position of the r-th leftover element after the
 // bulk-processed prefix. Confirm against the header.
317 out_dstr_tensor.get_thread_buffer().at(idx) =
318 static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
319 });
320
321 return out_dstr_tensor;
322}
323#endif
324} // namespace impl
325
// Returns a new distributed tensor holding src_tensor's elements converted to DstType.
// The implementation is chosen at compile time:
//   1. DstType is fp8_t or bf8_t, the source is float, and the buffer size is a
//      multiple of 4 (body line 333 missing from this listing; presumably
//      impl::cast_tile_pk_fp8_fp32).
//   2. With CK_TILE_USE_PK_FP16_TILE_CAST: DstType is fp16_t, the source is float,
//      and the buffer size is even (body line 340 missing; presumably
//      impl::cast_tile_pk_fp16_fp32).
//   3. With CK_TILE_USE_SUBDWORD_TILE_CAST: either element type is smaller than
//      4 bytes, handled by impl::cast_tile_opt_subdword.
//   4. Otherwise the generic path (line 350 missing from this listing).
// NOTE(review): if branch 1 does call cast_tile_pk_fp8_fp32, bf8_t destinations go
// through that function's fp8-only packed builtin. Verify that bf8 output is encoded
// correctly.
326template <typename DstType, typename SrcTensor>
327CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
328{
329 if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
330 std::is_same_v<typename SrcTensor::DataType, float> &&
331 (SrcTensor::get_thread_buffer_size() % 4 == 0))
332 {
334 }
335#if CK_TILE_USE_PK_FP16_TILE_CAST
336 else if constexpr(std::is_same_v<DstType, fp16_t> &&
337 std::is_same_v<typename SrcTensor::DataType, float> &&
338 (SrcTensor::get_thread_buffer_size() % 2 == 0))
339 {
341 }
342#endif
343#if CK_TILE_USE_SUBDWORD_TILE_CAST
344 else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
345 {
346 return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
347 }
348#endif
349 else
351}
352
353// no-op function for null_tensor arguments
354template <typename InOutElementFunc,
355 typename... MaybeNullTensor,
356 typename = std::enable_if_t<
357 std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
358CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
359{
360}
361
362// no-op function for null_tensor arguments
363template <typename InElementFunc,
364 typename... MaybeNullTensor,
365 typename = std::enable_if_t<
366 std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
367CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
368{
369 return null_tensor{};
370}
371
372} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor &in_dstr_tensors)
Definition tile_elementwise.hpp:185
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor &in_dstr_tensors)
Definition tile_elementwise.hpp:231
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition tile_elementwise.hpp:71
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition null_tensor.hpp:9
Definition tile/core/utility/functional.hpp:43