Composable Kernel: transpose_tile.hpp Source File

transpose_tile.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
19
20namespace ck_tile {
21namespace detail {
22
// transpose_tile2d_impl_in_thread
// Copies one thread's buffer data from in_tensor into out_tensor, where the two
// tensors' tile distributions order their Y dimensions differently. Each output Y
// dimension is paired with the input Y dimension that has the same rh_minor index
// in its distribution encoding (see y_dim_out_to_in), and data is moved either by
// direct scalar/vector copies or by gathering input vectors, transposing them,
// and scattering output vectors.
// Both tensors must have the same DataType (enforced by the static_assert below).
// NOTE(review): this extracted listing omits source line 24 (function name and
// first parameter; the symbol index gives the signature as
// transpose_tile2d_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor))
// as well as source lines 45, 94, 99, 129, 131, 133, 143-144, 168 and 185.
23template <typename OutTensor, typename InTensor>
25 const InTensor& in_tensor)
26{
27 constexpr auto I0 = number<0>{};
28
29 static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
30 "Data type for InTensor and OutTensor must be the same!");
31
32 using DataType = typename InTensor::DataType;
33
// Descriptors mapping a Y-index tuple to a linear offset in each thread buffer.
34 constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
35 constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
36
37 // y_dim_out_to_in
38 // For swapped Hs tile case I need only get_rh_minor_to_y
39 // since rh_major are already swapped due to swapped Hs.
// Builds a map rh_minor -> Y-dimension index from DstrEncode::ys_to_rhs_minor_.
// NOTE(review): the map is keyed on rh_minor only, so two Y dims that differ only
// in rh_major would overwrite each other; this relies on the swapped-Hs
// precondition described above. The loop header introducing `i` (source line 45)
// is missing from this listing.
40 constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
41 using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
42
43 map<index_t, index_t> rh_minor_to_y_;
44
46 constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
47
48 rh_minor_to_y_(rh_minor) = i;
49 });
50
51 return rh_minor_to_y_;
52 };
53
54 // In swapped Hs case <Y,X> -> <X,Y> tile
55 // we have same rh_major, but reversed rh_minor!
56 constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
57 constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
58
// y_dim_out_to_in[y_out] = the input Y dim sharing y_out's rh_minor. It is later
// used as the new2old order when reordering an input-ordered Y index into output
// Y order.
59 // Is this really needed?? Should we have simple reverse here??
60 constexpr auto y_dim_out_to_in = [&] {
61 map<index_t, index_t> y_dim_out_to_in_;
62
63 for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
64 {
65 y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
66 }
67
68 return y_dim_out_to_in_;
69 }();
70
71 constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
72 constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
73
// The input vector dim is the last input Y dim. The output vector dim is the
// input Y dim that becomes the output's last Y dim.
74 // input and output vector dim in the order of input Y dims
75 constexpr index_t y_dim_vec_in = NDimY - 1;
76 constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
77
78 // vector lengths
79 constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
80 constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
81
// Each access handles a vec_length_out x vec_length_in patch: num_vec_in input
// vectors of length vec_length_in become num_vec_out output vectors of length
// vec_length_out.
82 // # of vectors
83 constexpr index_t num_vec_in = vec_length_out;
84 constexpr index_t num_vec_out = vec_length_in;
85
// Per-access step of the space-filling curve: all 1s when vec_length_in == 1;
// otherwise the full Y length along both vector dims, so that one access covers
// one whole patch.
// NOTE(review): generate_array's size argument (source line 94) and the second
// space_filling_curve template argument (source line 99) are missing from this
// listing.
86 // SFC
87 constexpr auto scalars_per_access_arr = generate_array(
88 [&](auto i) {
89 if constexpr(vec_length_in == 1)
90 return 1;
91 else
92 return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1;
93 },
95
96 constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
97
98 using SFC_Y = space_filling_curve<decltype(y_lengths),
100 decltype(scalars_per_access)>;
101
102 constexpr index_t num_access = SFC_Y::get_num_of_access();
103
104 static_assert(num_access > 0, "wrong! num_access should be larger than 0");
105
// Fast path: when either side has a single vector per patch, each access is
// copied directly instead of being transposed. The copy is scalar when
// vec_length_in == 1 and vector-wide otherwise.
// NOTE(review): the `Vec` alias and the get_as index arguments (source lines 129,
// 131, 133) are missing from this listing.
106 if constexpr(num_vec_in == 1 || num_vec_out == 1)
107 {
108 // loop over SFC
109 static_for<0, num_access, 1>{}([&](auto iAccess) {
110 // data index [y0, y1, ...] in the order of input tensor
111 constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
112 constexpr auto idx_y_in =
113 generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
114 constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
115 static_assert(in_offset % vec_length_in == 0);
// Same start index, reordered from input Y order into output Y order.
116 constexpr auto idx_y_out_tmp =
117 generate_array([&](auto ii) { return idx_y_start[ii].value; }, number<NDimY>{});
118 constexpr auto idx_y_out =
119 container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
120 constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
121 if constexpr(vec_length_in == 1)
122 {
123
124 out_tensor.get_thread_buffer()[number<out_offset>{}] =
125 in_tensor.get_thread_buffer()[number<in_offset>{}];
126 }
127 else
128 {
130 out_tensor.get_thread_buffer().template get_as<Vec>(
132 in_tensor.get_thread_buffer().template get_as<Vec>(
134 }
135 });
136 }
137 else
138 {
// General path: for each access, gather num_vec_in input vectors, transpose
// them, and scatter num_vec_out output vectors.
// NOTE(review): the in_vectors/out_vectors declarations (source lines 143-144)
// and the transpose call (source line 168) are missing from this listing. The
// symbol index references tile/core/utility/transpose_vectors.hpp, which is
// presumably the helper used there — confirm against the header.
139 using InVec = array<DataType, vec_length_in>;
140 using OutVec = array<DataType, vec_length_out>;
141
142 // in/out vectors to be transposed
145
146 // loop over SFC and do transpose
147 static_for<0, num_access, 1>{}([&](auto iAccess) {
148 // data index [y0, y1, ...] in the order of input tensor
149 constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
150
// Input vector i starts at idx_y_start advanced by i along y_dim_vec_out. Offsets
// must be multiples of vec_length_in because the buffer is indexed as an array of
// InVec.
151 // get input vectors
152 static_for<0, num_vec_in, 1>{}([&](auto i) {
153 constexpr auto idx_y_in = generate_tuple(
154 [&](auto ii) {
155 return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
156 },
157 number<NDimY>{});
158
159 constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
160 static_assert(in_offset % vec_length_in == 0);
161
162 in_vectors(i).template get_as<InVec>()(I0) =
163 in_tensor.get_thread_buffer()
164 .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
165 });
166
167 // transpose
169
// Output vector i starts at idx_y_start advanced by i along y_dim_vec_in. The
// index is then reordered into output Y order, and its offset must be a multiple
// of vec_length_out for OutVec indexing.
// NOTE(review): the set_as index argument (source line 185) is missing from this
// listing.
170 // set output vectors
171 static_for<0, num_vec_out, 1>{}([&](auto i) {
172 constexpr auto idx_y_out_tmp = generate_array(
173 [&](auto ii) {
174 return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
175 },
176 number<NDimY>{});
177
178 constexpr auto idx_y_out =
179 container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
180
181 constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
182 static_assert(out_offset % vec_length_out == 0);
183
184 out_tensor.get_thread_buffer().template set_as<OutVec>(
186 out_vectors[i].template get_as<OutVec>()[I0]);
187 });
188 });
189 }
190}
191
192} // namespace detail
193
// transpose_tile2d
// Writes the transpose of `in` into `out`, where out's tile is in's tile with its
// two H dimensions swapped (<Y, X> -> <X, Y>). If OutDataType differs from
// InDataType, `in` is converted to OutDataType first; otherwise it is used as-is.
// Supported only when all of the following hold:
//   - both encodings have the same R lengths;
//   - their H lengths are reversed;
//   - each side has exactly two Y dimensions;
//   - the per-thread Y lengths are reversed.
// Any other combination is rejected at compile time.
// NOTE(review): this extracted listing omits source lines 221 and 235. The symbol
// index suggests they call tile_elementwise_in(type_convert<...>, in) and
// detail::transpose_tile2d_impl_in_thread(out, in_tmp) respectively — confirm
// against the header.
194template <typename OutTensor, typename InTensor>
195CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in)
196{
197 using InDataType = typename InTensor::DataType;
198 using OutDataType = typename OutTensor::DataType;
199
200 using InTileDistr = typename InTensor::StaticTileDistribution;
201 using OutTileDistr = typename OutTensor::StaticTileDistribution;
202
203 using InDstrEncode = typename InTileDistr::DstrEncode;
204 using OutDstrEncode = typename OutTileDistr::DstrEncode;
205
206 using InThreadTensorDesc = typename InTensor::ThreadTensorDesc;
207 using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc;
208
209 // Ys:
210 constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
211 constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
212
213 // type convert
214 const auto in_tmp = [&]() {
215 if constexpr(std::is_same_v<OutDataType, InDataType>)
216 {
217 return in;
218 }
219 else
220 {
222 }
223 }();
224
225 // Scenario where we switch from tile <Y, X> -> <X, Y> - only 2D tiles!
226 // we preserve Ps but swap Ys: <Y1, Y0> -> <Y0, Y1>
227 if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
228 InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) &&
229 InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
230 in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths))
231 // Any condition on Ps ??
232 // InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
233 // InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
234 {
236 }
237 else
238 {
// A plain `static_assert(false, ...)` is not value-dependent. Before C++23
// (P2593R1) it makes the whole template ill-formed even though this branch is
// discarded, and GCC < 13 / Clang < 17 reject it eagerly. Making the condition
// depend on a template parameter defers the check until this branch is actually
// instantiated.
239 static_assert(!std::is_same_v<OutTensor, OutTensor>, "Provided tensors could not be transposed!");
240 }
241}
242
243} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition arch.hpp:385
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor &out_tensor, const InTensor &in_tensor)
Definition transpose_tile.hpp:24
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition tile/core/container/container_helper.hpp:39
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple< Ts... > &t)
Definition tile/core/container/tuple.hpp:583
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto generate_array(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1115
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition map.hpp:16
Definition space_filling_curve.hpp:20
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67
Definition tile/core/utility/transpose_vectors.hpp:20
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10