block_fmha_fwd_splitkv_combine_pipeline.hpp Source File

block_fmha_fwd_splitkv_combine_pipeline.hpp Source File

Composable Kernel: block_fmha_fwd_splitkv_combine_pipeline.hpp Source File
block_fmha_fwd_splitkv_combine_pipeline.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11namespace detail {
12template <index_t N>
13struct log2;
14
15template <>
16struct log2<4> : std::integral_constant<index_t, 2>
17{
18};
19
20template <>
21struct log2<8> : std::integral_constant<index_t, 3>
22{
23};
24
25template <>
26struct log2<16> : std::integral_constant<index_t, 4>
27{
28};
29
30template <>
31struct log2<32> : std::integral_constant<index_t, 5>
32{
33};
34
35template <>
36struct log2<64> : std::integral_constant<index_t, 6>
37{
38};
39
40template <>
41struct log2<128> : std::integral_constant<index_t, 7>
42{
43};
44} // namespace detail
45
46template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
48{
51
55
56 static constexpr index_t kNumWarps = Problem::kNumWarps;
57 static constexpr index_t kBlockSize = Problem::kBlockSize;
58
59 static constexpr index_t kHeadDimV = Problem::kHeadDimV;
60 static constexpr index_t kM0 = Problem::kM0;
61 static constexpr index_t kN1 = Problem::kN1;
62
63 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
64 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
65 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
66 static constexpr bool kStoreLSE = Problem::kStoreLSE;
67 static constexpr index_t kMaxSplits = Problem::kMaxSplits;
68
69 static constexpr index_t kAlignmentLSE =
70 kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE<Problem>();
72
73 static constexpr index_t kAlignmentOacc =
74 kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
75
76 static constexpr index_t kAlignmentO =
77 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
78
79 static constexpr index_t kBlockPerCu = []() {
80 if constexpr(Problem::kBlockPerCu != -1)
81 return Problem::kBlockPerCu;
82 else
83 {
84 if constexpr(kHeadDimV <= 32)
85 {
86 constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
87 return occupancy[detail::log2<kMaxSplits>::value - 2];
88 }
89 else if constexpr(kHeadDimV <= 128)
90 {
91 constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
92 return occupancy[detail::log2<kMaxSplits>::value - 2];
93 }
94 else if constexpr(kHeadDimV <= 256)
95 {
96 constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
97 return occupancy[detail::log2<kMaxSplits>::value - 2];
98 }
99 }
100 }();
101
102 static constexpr const char* name = "unused";
103
105 {
106 return Policy::template GetSmemSize<Problem>();
107 }
108
109 template <typename LSEaccDramBlockWindowTmp,
110 typename OaccDramBlockWindowTmp,
111 typename LSEDramBlockWindowTmp,
112 typename LSEElementFunction,
113 typename OaccElementFunction>
115 operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp,
116 const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp,
117 LSEDramBlockWindowTmp& lse_dram_window_tmp,
118 const LSEElementFunction& lse_element_func,
119 const OaccElementFunction& o_acc_element_func,
120 index_t num_splits,
121 void* smem_ptr) const
122 {
123 // lse_acc tile in LDS
124 LSEDataType* lse_acc_lds_ptr =
125 static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
126 auto lse_acc_lds = [=, lds_desc = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>()](
127 index_t row, index_t col) -> LSEDataType& {
128 return lse_acc_lds_ptr[lds_desc.calculate_offset(make_tuple(row, col))];
129 };
130
131 auto lse_acc_lds_write_window = [&]() {
133 lse_acc_lds_ptr, Policy::template MakeLSEaccLdsStoreBlockDescriptor<Problem>());
134 return make_tile_window(view, make_tuple(number<kMaxSplits>{}, number<kM0>{}), {0, 0});
135 }();
136
137 auto lse_acc_dram_window =
138 make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(),
139 lse_acc_dram_block_window_tmp.get_window_lengths(),
140 lse_acc_dram_block_window_tmp.get_window_origin(),
141 Policy::template MakeLSEaccDramTileDistribution<Problem>());
142
143 // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
144 auto lse_acc_tile = load_tile(lse_acc_dram_window);
145 store_tile(lse_acc_lds_write_window, lse_acc_tile);
146
148 Policy::template MakeLSEaccRegTileDistribution<Problem>());
149
150 __builtin_amdgcn_sched_barrier(0);
152 // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
153 // and fill up -INF values outside the [kM0, num_splits] region.
154 {
155 constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
156 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
157 sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
158 constexpr auto i_j_idx = make_tuple(idx0, idx1);
159 const auto x_indices = get_x_indices_from_distributed_indices(
160 lse_accum.get_tile_distribution(), i_j_idx);
161
162 const auto col = x_indices.at(number<1>{});
163 if(col < num_splits)
164 {
165 const auto row = x_indices.at(number<0>{});
166
167 lse_accum(i_j_idx) = lse_acc_lds(row, col);
168 }
169 else
170 {
171 lse_accum(i_j_idx) = -numeric<LSEDataType>::infinity();
172 }
173 });
174 });
175 }
176
177 // compute the logsumexp of the LSE along the split dimension.
178 const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
179 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
180
181 auto lse_max = block_tile_reduce<LSEDataType>(
182 lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
184
185 decltype(lse_accum) lse_exp;
186 {
187 constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
188 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
189 constexpr auto i_idx = make_tuple(idx0);
190 if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
191 {
192 sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
193 constexpr auto i_j_idx = make_tuple(idx0, idx1);
194
195 lse_exp(i_j_idx) = ck_tile::type_convert<LSEDataType>(0.0f);
196 });
197 }
198 else
199 {
200 sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
201 constexpr auto i_j_idx = make_tuple(idx0, idx1);
202
203 lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
204 });
205 }
206 });
207 }
208
209 auto lse_sum = block_tile_reduce<LSEDataType>(
210 lse_exp, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
212
213 decltype(lse_max) lse_logsum;
214 {
215 constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
216 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
217 constexpr auto i_idx = make_tuple(idx0);
218
219 if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
220 lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
221 else
222 lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
223 });
224 }
225
226 // sync before rewriting lse_acc_lds
228 // store the lse scales in shared memory.
229 {
230 constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
231 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
232 constexpr auto i_idx = make_tuple(idx0);
233 if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
234 {
235 sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
236 constexpr auto i_j_idx = make_tuple(idx0, idx1);
237
238 const auto x_indices = get_x_indices_from_distributed_indices(
239 lse_accum.get_tile_distribution(), i_j_idx);
240
241 const auto col = x_indices.at(number<1>{});
242 if(col < num_splits)
243 {
244 const auto row = x_indices.at(number<0>{});
245
246 lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
247 }
248 });
249 }
250 else
251 {
252 sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
253 constexpr auto i_j_idx = make_tuple(idx0, idx1);
254
255 const auto x_indices = get_x_indices_from_distributed_indices(
256 lse_accum.get_tile_distribution(), i_j_idx);
257
258 const auto col = x_indices.at(number<1>{});
259 if(col < num_splits)
260 {
261 const auto row = x_indices.at(number<0>{});
262
263 lse_acc_lds(row, col) =
264 ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
265 }
266 });
267 }
268 });
269 }
270
271 if constexpr(kStoreLSE)
272 {
273 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
274 }
275
276 // First each warp processes its own part of splits
277
278 auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
279 auto o_acc_dram_window =
280 make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
281 o_acc_dram_block_window_tmp.get_window_lengths(),
282 o_acc_dram_block_window_tmp.get_window_origin(),
283 o_acc_dist);
284
285 // shape=[kNumWarps * KM0, kN1]
286 auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
287 clear_tile(o_acc);
288
289 const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
290
291 __builtin_amdgcn_sched_barrier(0);
293 // each warp handles a [KM0, kN1] tile
294 for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
295 {
296 auto o_tile = load_tile(o_acc_dram_window);
297 const index_t i_split = split_start + get_warp_id();
298 const index_t row_start = kM0 * get_warp_id();
299 {
300 constexpr auto spans = decltype(o_acc)::get_distributed_spans();
301 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
302 sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
303 constexpr auto i_j_idx = make_tuple(idx0, idx1);
304 const auto x_indices = get_x_indices_from_distributed_indices(
305 o_acc.get_tile_distribution(), i_j_idx);
306
307 const auto row = x_indices.at(number<0>{});
308
309 const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
310 o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
311 });
312 });
313 }
314
315 move_tile_window(o_acc_dram_window, {kNumWarps * kM0, 0});
316 }
317
318 // Then each warps combines partial o_acc results into one
319
320 // kNumWarps o_acc tiles in LDS. shape=[kNumWarps * kM0, kN1]
321 OaccDataType* o_acc_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
322 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
323
324 {
325 auto o_acc_lds_store_window = [&]() {
326 auto desc = Policy::template MakeOaccLdsBlockDescriptor<Problem>();
327 auto view = make_tensor_view<address_space_enum::lds>(o_acc_lds_ptr, desc);
328 return make_tile_window(view, desc.get_lengths(), {0, 0});
329 }();
330 store_tile(o_acc_lds_store_window, o_acc);
331 }
332
333 auto o_acc_result_dist = Policy::template MakeOaccResultDramTileDistribution<Problem>();
334
335 auto o_acc_lds_load_window = [&]() {
336 auto desc = Policy::template MakeOaccLdsBlockDescriptor<Problem>();
337 auto view = make_tensor_view<address_space_enum::lds>(o_acc_lds_ptr, desc);
338 return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_result_dist);
339 }();
340
341 auto o_acc_result = make_static_distributed_tensor<OaccDataType>(o_acc_result_dist);
342 clear_tile(o_acc_result);
343
344 __builtin_amdgcn_sched_barrier(0);
346 static_for<0, kNumWarps, 1>{}([&](auto) {
347 auto o_acc_in = load_tile(o_acc_lds_load_window);
348
349 {
350 constexpr auto spans = decltype(o_acc_result)::get_distributed_spans();
351 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
352 sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
353 constexpr auto i_j_idx = make_tuple(idx0, idx1);
354 o_acc_result(i_j_idx) += o_acc_in(i_j_idx);
355 });
356 });
357 }
358
359 move_tile_window(o_acc_lds_load_window, {kM0, 0});
360 });
361
362 return tile_elementwise_in(o_acc_element_func, o_acc_result);
363 }
364
365 template <typename LSEaccDramBlockWindow,
366 typename OaccDramBlockWindow,
367 typename LSEDramBlockWindow>
368 CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window,
369 const OaccDramBlockWindow& o_acc_dram_block_window,
370 LSEDramBlockWindow& lse_dram_block_window,
371 index_t num_splits,
372 void* smem_ptr) const
373 {
374 return operator()(lse_acc_dram_block_window,
375 o_acc_dram_block_window,
376 lse_dram_block_window,
377 identity{},
378 identity{},
379 num_splits,
380 smem_ptr);
381 }
382};
383
384} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:48
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:49
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, const OaccDramBlockWindowTmp &o_acc_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const OaccElementFunction &o_acc_element_func, index_t num_splits, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:115
static constexpr index_t kAlignmentOacc
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:73
static constexpr index_t kNumWarps
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:56
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:53
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:50
static constexpr const char * name
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:102
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow &lse_acc_dram_block_window, const OaccDramBlockWindow &o_acc_dram_block_window, LSEDramBlockWindow &lse_dram_block_window, index_t num_splits, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:368
static constexpr index_t kBlockPerCu
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:79
static constexpr index_t kMaxSplits
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:67
static constexpr bool kStoreLSE
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:66
static constexpr index_t kM0
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:60
static constexpr bool kIsGroupMode
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:63
static constexpr index_t kBlockSize
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:57
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:54
static constexpr index_t kAlignmentO
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:76
static constexpr index_t kN1
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:61
static constexpr index_t kHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:59
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:104
static constexpr index_t kAlignmentLSEacc
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:71
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:65
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:64
static constexpr index_t kAlignmentLSE
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:69
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:52
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:13
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43