flatmm_32x512x128_1x4x1_16x16x32.hpp Source File

flatmm_32x512x128_1x4x1_16x16x32.hpp Source File#

Composable Kernel: flatmm_32x512x128_1x4x1_16x16x32.hpp Source File
flatmm_32x512x128_1x4x1_16x16x32.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// A async load to LDS, B direct to AGPR
13// B matrix preshuffled in br*kr*w
14// require 4 wave, occupancy=1c
15// agpr usage:256
16// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
17//
18// for this gemm, 4 16x16x16 transposed layout
19// input A vgpr layout
20// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
21// v16-v31: [16:31](gemm_m)x128(gemm_k)
22
23// input B vgpr layout
24// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
25// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
26// ......................
27// v112-v127: [448:463](gemm_n)x128(gemm_k)
28
29// output C vgpr layout
30// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
31// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
32// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
33// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
34// ......................
35// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
36// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
38{
// Compile-time tile geometry shared by the micro-kernels in this file:
// block tile (Block_*), warps per block (WarpPerBlock_*), warp tile (Warp_*)
// and the derived per-warp repeat counts (Repeat_*).
39 static constexpr index_t Block_M = 32; // block tile extent along gemm_m
40 static constexpr index_t Block_N = 512; // block tile extent along gemm_n
41 static constexpr index_t Block_K = 128; // block tile extent along gemm_k
42
43 static constexpr index_t WarpPerBlock_M = 1; // warps along gemm_m
44 static constexpr index_t WarpPerBlock_N = 4; // warps along gemm_n
45 static constexpr index_t WarpPerBlock_K = 1; // warps along gemm_k
46
47 static constexpr index_t NumWarps = 4; // == WarpPerBlock_M * WarpPerBlock_N * WarpPerBlock_K (1 * 4 * 1)
48
49 static constexpr index_t Warp_M = 16; // warp tile extent along gemm_m
50 static constexpr index_t Warp_N = 16; // warp tile extent along gemm_n
51 static constexpr index_t Warp_K = 32; // 16 * SubKPacks
52
53 static constexpr index_t BlockSize = 256; // threads per block; 256 / NumWarps = 64 lanes per warp (presumably wave64 - confirm)
54
55 static constexpr index_t SubKPacks = 2; // this is used to guarantee every thread can do dwordx4
56
57 // TODO: note Nr/Kr/W need consider SubKPacks
58 static constexpr index_t Block_W = Warp_N * Warp_K; // 16 * 32 = 512 elements
59 static constexpr index_t Block_Nr = Block_N / Warp_N; // 512 / 16 = 32 element, 4 per wave; NOTE(review): 32 / WarpPerBlock_N = 8 (== Repeat_N), so "4 per wave" looks stale - confirm
60 static constexpr index_t Block_Kr = Block_K / Warp_K; // 128 / 32 = 4
61
62 static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 32 / (16 * 1) = 2
63 static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 512 / (16 * 4) = 8
64 static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 128 / (32 * 1) = 4
65
66 private:
// Primary template, declared but never defined. The defaulted third parameter
// is the SFINAE slot: the two partial specializations below fill it with
// std::enable_if_t<(LanesPerK >= WarpSize)> and std::enable_if_t<(LanesPerK < WarpSize)>,
// so exactly one of them is selected for any (LanesPerK, WarpSize) pair.
67 template <index_t LanesPerK, index_t WarpSize, typename = void>
68 struct LdsStoreDescSelector;
69
70 template <index_t LanesPerK, index_t WarpSize>
71 struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK >= WarpSize)>>
72 {
73 template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
74 static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
75 {
76 // need multiple waves to load K
77 static_assert(LanesPerK % WarpSize == 0);
78 constexpr index_t wavesPerK = LanesPerK / WarpSize;
79 if constexpr(wavesPerK > NumWarps)
80 {
81 // TODO: need multiple issues along K to load all data
82 }
83 else
84 {
85 constexpr index_t wavesPerM = NumWarps / wavesPerK;
86 constexpr index_t NumIssues = Block_M / wavesPerM;
87 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
89 number<wavesPerM>{}, // m1
90 number<wavesPerK>{}, // k0
91 number<WarpSize>{}, // k1
92 number<KVector>{}), // k2
93 make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
94 number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
96 number<KVector>{}, // k1
97 number<1>{}), // k2
98 number<KVector>{}, // lds store vector(actually no explicit store)
99 number<1>{});
100
101 constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
102 lds_block_desc_0,
107 make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
108 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
109
110 return lds_block_desc_issues_warps_lanes;
111 }
112 }
113 };
114
115 template <index_t LanesPerK, index_t WarpSize>
116 struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK < WarpSize)>>
117 {
118 template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
119 static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
120 {
121 // lanes within a wave load different M but same K
122 static_assert(WarpSize % LanesPerK == 0);
123 constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
124 constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
125
126 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
128 number<LaneGroups>{}, // m1
129 number<NumWarps>{}, // m2
130 number<LanesPerK>{}, // k0
131 number<KVector>{}), // k1
132 make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
133 number<Block_K>{}, // m1
135 number<KVector>{}, // k0
136 number<1>{}), // k1
137 number<KVector>{}, // lds store vector(actually no explicit store)
138 number<1>{});
139
140 constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
141 lds_block_desc_0,
146 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
147 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
148
149 return lds_block_desc_issues_warps_lanes;
150 }
151 };
152
153 public:
154 static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
155 {
156 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
161 sequence<2, 1>, // !! note here is different
163
165
166 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
167 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
168 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
169 return c_block_dstr;
170 }
171
// Creates the C block tile: a static distributed tensor of float whose
// layout is the tile distribution returned by MakeCBlockDist().
// Returns the tensor by value; this function does not write any elements.
172 static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
173 {
174 using CDataType = float; // element type of the C tile
175 constexpr auto c_block_dstr = MakeCBlockDist(); // distribution is a compile-time constant
176 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
177 return c_block_tensor;
178 }
179
181 {
182 // A async->LDS
183 constexpr index_t WarpSize = ck_tile::get_warp_size();
184
185 constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
186 constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
187 constexpr index_t KPad = KPack_; // pad between warps
188
189 static_assert(Block_K % KVector == 0);
190 constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
191
192 return LdsStoreDescSelector<LanesPerK, WarpSize>::
193 template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
194 }
195
196 // template <typename Problem>
198 {
199 // load from LDS to register, every wave has same layout
200 constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
201 constexpr index_t KPad = KPack_; // pad between warps
202
203 constexpr index_t kAMLane = 16;
204 constexpr index_t kABKLane = 4;
205 constexpr index_t kABKPerLane = 4;
206 constexpr index_t kKIter = 2;
207 static_assert(KPack_ == (kABKPerLane * kKIter));
208
209 constexpr auto lds_block_desc_0 =
211 number<kAMLane>{}, // m1 p
212 number<Repeat_K>{}, // k0 y
213 number<kABKLane>{}, // k1 p
214 number<KPack_>{}), // k2 y-vector
215 make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
218 number<KPack_>{}, // k1
219 number<1>{}), // k2
220 number<KPack_>{}, // lds load vector
221 number<1>{});
222
223 constexpr auto lds_desc_m_k = transform_tensor_descriptor(
224 lds_block_desc_0,
230
231 return lds_desc_m_k;
232 }
233
234 static constexpr auto GetGemm_AWarpEnc()
235 {
236 constexpr index_t kAMLane = 16;
237 constexpr index_t kABKLane = 4;
238 constexpr index_t kABKPerLane = 4;
239 constexpr index_t kKIter = 2;
240
241 using enc_ = tile_distribution_encoding<
248 return enc_{};
249 }
250
252 {
253 // return 32 * (128 + 8) * sizeof(bf16_t);
254 return MakeLdsLoadDesc_A().get_element_space_size() * sizeof(bf16_t) * 2; // 2 lds buffers
255 }
256};
257
258// clang-format off
// Read-write ("+") operand list for an inline-asm statement that uses a single
// accumulator set. Expands to: loop_cnt with the "s" constraint, v_acc[0..15]
// with the "v" constraint, and smem with the "r" constraint.
// The expansion site must have `loop_cnt`, `v_acc` (at least 16 elements) and
// `smem` in scope; the symbolic names in [] are what the asm body refers to.
// Presumably placed in the output-operand section of the asm statements in the
// operator() bodies of this file - confirm against the full source.
259#define _EXPAND_ASM_ARGS_OUT_ONE_ACC \
260 [s_loop_cnt]"+s"(loop_cnt), \
261 [v_acc_0]"+v"(v_acc[0]), \
262 [v_acc_1]"+v"(v_acc[1]), \
263 [v_acc_2]"+v"(v_acc[2]), \
264 [v_acc_3]"+v"(v_acc[3]), \
265 [v_acc_4]"+v"(v_acc[4]), \
266 [v_acc_5]"+v"(v_acc[5]), \
267 [v_acc_6]"+v"(v_acc[6]), \
268 [v_acc_7]"+v"(v_acc[7]), \
269 [v_acc_8]"+v"(v_acc[8]), \
270 [v_acc_9]"+v"(v_acc[9]), \
271 [v_acc_10]"+v"(v_acc[10]), \
272 [v_acc_11]"+v"(v_acc[11]), \
273 [v_acc_12]"+v"(v_acc[12]), \
274 [v_acc_13]"+v"(v_acc[13]), \
275 [v_acc_14]"+v"(v_acc[14]), \
276 [v_acc_15]"+v"(v_acc[15]), \
277 [s_mem_]"+r"(smem)
278
// Read-write ("+") operand list for an inline-asm statement that uses two
// accumulator sets. Same shape as _EXPAND_ASM_ARGS_OUT_ONE_ACC but covers
// v_acc[0..31]: loop_cnt with "s", v_acc[0..31] with "v", smem with "r".
// The expansion site must have `loop_cnt`, `v_acc` (at least 32 elements) and
// `smem` in scope. Presumably used by the Is2B branch, which declares
// `fp32x4_t v_acc[32]` - confirm against the full source.
279#define _EXPAND_ASM_ARGS_OUT_TWO_ACC \
280 [s_loop_cnt]"+s"(loop_cnt), \
281 [v_acc_0]"+v"(v_acc[0]), \
282 [v_acc_1]"+v"(v_acc[1]), \
283 [v_acc_2]"+v"(v_acc[2]), \
284 [v_acc_3]"+v"(v_acc[3]), \
285 [v_acc_4]"+v"(v_acc[4]), \
286 [v_acc_5]"+v"(v_acc[5]), \
287 [v_acc_6]"+v"(v_acc[6]), \
288 [v_acc_7]"+v"(v_acc[7]), \
289 [v_acc_8]"+v"(v_acc[8]), \
290 [v_acc_9]"+v"(v_acc[9]), \
291 [v_acc_10]"+v"(v_acc[10]), \
292 [v_acc_11]"+v"(v_acc[11]), \
293 [v_acc_12]"+v"(v_acc[12]), \
294 [v_acc_13]"+v"(v_acc[13]), \
295 [v_acc_14]"+v"(v_acc[14]), \
296 [v_acc_15]"+v"(v_acc[15]), \
297 [v_acc_16]"+v"(v_acc[16]), \
298 [v_acc_17]"+v"(v_acc[17]), \
299 [v_acc_18]"+v"(v_acc[18]), \
300 [v_acc_19]"+v"(v_acc[19]), \
301 [v_acc_20]"+v"(v_acc[20]), \
302 [v_acc_21]"+v"(v_acc[21]), \
303 [v_acc_22]"+v"(v_acc[22]), \
304 [v_acc_23]"+v"(v_acc[23]), \
305 [v_acc_24]"+v"(v_acc[24]), \
306 [v_acc_25]"+v"(v_acc[25]), \
307 [v_acc_26]"+v"(v_acc[26]), \
308 [v_acc_27]"+v"(v_acc[27]), \
309 [v_acc_28]"+v"(v_acc[28]), \
310 [v_acc_29]"+v"(v_acc[29]), \
311 [v_acc_30]"+v"(v_acc[30]), \
312 [v_acc_31]"+v"(v_acc[31]), \
313 [s_mem_]"+r"(smem)
314
// Input-only operand list for the inline-asm statements. Expands to:
//   - res_a[0..3] and res_b[0..3] with the "s" constraint;
//   - cached_coords_a[0..7] and cached_coords_b[0..7], each scaled by
//     sizeof(ADataType) / sizeof(BDataType) (i.e. converted to byte offsets)
//     and cast to index_t, with the "v" constraint;
//   - the byte offset of a_sld.cached_coords_[0] with the "v" constraint;
//   - m0_init_value and size_per_issue with the "s" constraint;
//   - smem_buf_size and sld_os[0..7].value with the "n" constraint, so they
//     must be compile-time integer constants at the expansion site;
//   - tile_offset_a_bytes and tile_offset_b_bytes with the "s" constraint.
// All of the names above, plus the ADataType / BDataType aliases, must be in
// scope where the macro is expanded.
315#define _EXPAND_ASM_ARGS_IN \
316 [s_res_a0]"s"(res_a[0]), \
317 [s_res_a1]"s"(res_a[1]), \
318 [s_res_a2]"s"(res_a[2]), \
319 [s_res_a3]"s"(res_a[3]), \
320 [s_res_b0]"s"(res_b[0]), \
321 [s_res_b1]"s"(res_b[1]), \
322 [s_res_b2]"s"(res_b[2]), \
323 [s_res_b3]"s"(res_b[3]), \
324 [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))), \
325 [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))), \
326 [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))), \
327 [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))), \
328 [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))), \
329 [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))), \
330 [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))), \
331 [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))), \
332 \
333 [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))), \
334 [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))), \
335 [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))), \
336 [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))), \
337 [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))), \
338 [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))), \
339 [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))), \
340 [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))), \
341 \
342 [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),\
343 [s_m0_init]"s"(m0_init_value), \
344 [s_size_per_issue]"s"(size_per_issue), \
345 [smem_sz]"n"(smem_buf_size), \
346 [sld_os_0]"n"(sld_os[number<0>{}].value), \
347 [sld_os_1]"n"(sld_os[number<1>{}].value), \
348 [sld_os_2]"n"(sld_os[number<2>{}].value), \
349 [sld_os_3]"n"(sld_os[number<3>{}].value), \
350 [sld_os_4]"n"(sld_os[number<4>{}].value), \
351 [sld_os_5]"n"(sld_os[number<5>{}].value), \
352 [sld_os_6]"n"(sld_os[number<6>{}].value), \
353 [sld_os_7]"n"(sld_os[number<7>{}].value), \
354 [s_tile_os_a]"s"(tile_offset_a_bytes), \
355 [s_tile_os_b]"s"(tile_offset_b_bytes)
356
// Clobber list for the inline-asm statements. Declares as clobbered:
//   - "memory";
//   - registers "a0" .. "a255" (the file header comment puts AGPR usage at 256);
//   - registers "s16" .. "s23" and "s86";
//   - registers "v64" .. "v127".
// The Is2B asm statements append "s24" .. "s27" after this macro.
// The list must stay in sync with the registers the asm body names directly;
// that body is not visible in this listing.
357#define _EXPAND_ASM_ARGS_CLOBBER \
358 "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", \
359 "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", \
360 "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", \
361 "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", \
362 "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", \
363 "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", \
364 "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", \
365 "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", \
366 "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", \
367 "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", \
368 "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", \
369 "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", \
370 "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", \
371 "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", \
372 "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", \
373 "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", \
374 "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", \
375 "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", \
376 "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", \
377 "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", \
378 "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", \
379 "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", \
380 "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", \
381 "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", \
382 "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", \
383 "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", \
384 "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", \
385 "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", \
386 "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", \
387 "a252", "a253", "a254", "a255", \
388 "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", \
389 "s86", \
390 "v64", "v65", "v66", "v67", "v68", "v69", \
391 "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", \
392 "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", \
393 "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", \
394 "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", \
395 "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", \
396 "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", \
397 "v124", "v125", "v126", "v127"
398// clang-format on
399
401{
404
405 // TODO: need paired with tile_window_linear!
406 // TODO: need call init_raw() before call this function!
407 // Is2B: originally the B matrix has 2 prefetch buffers. If this is set to true,
408 // one A matrix can serve 2 B matrices (B0/B1); B0 and B1 still have the same tile size
409 template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
410 CK_TILE_DEVICE auto
411 operator()(const ARes& res_a,
412 const ACoords& cached_coords_a,
413 const BRes& res_b,
414 const BCoords& cached_coords_b,
415 CK_TILE_LDS_ADDR void* smem,
416 index_t k,
417 index_t tile_offset_a, // for each tile, the offset to move for each unroll
418 index_t tile_offset_b,
419 bool_constant<Is2B> = {}) // for each tile, the offset to move for each unroll
420 {
421 static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
422 static_assert(BCoords::size() == Repeat_N);
423
424 auto a_sst = make_tile_window(
426 reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
427 MakeLdsStoreDesc_A().get_lengths(),
428 {0, 0, 0});
429
430 auto a_sld = [&]() {
431 constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
432 constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
433 sequence<WarpPerBlock_N>,
434 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
435 tuple<sequence<1, 0>>,
436 tuple<sequence<1, 0>>,
437 sequence<1, 2>,
438 sequence<0, 0>>{};
439 constexpr auto a_block_dstr_encode =
440 detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
443 reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
444 MakeLdsLoadDesc_A().get_lengths(),
445 {0, 0},
446 make_static_tile_distribution(a_block_dstr_encode));
447 }();
448
449 const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
450 const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
451
452 const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
453 constexpr auto smem_buf_size =
454 MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
455 static_assert(a_sld.get_num_of_access() == 8);
456 constexpr auto sld_os = generate_tuple(
457 [&](auto i_access) {
458 return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
459 },
460 number<a_sld.get_num_of_access()>{});
461
462 index_t loop_cnt = k / Block_K;
463
464 if constexpr(Is2B)
465 {
466 // this is the acc thread buffer
467 fp32x4_t v_acc[32]{.0f};
468
469 // B nr->kr
470#pragma clang diagnostic push
471#pragma clang diagnostic ignored "-Winline-asm"
472 // clang-format off
473 asm volatile(
474#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
475#define CK_TILE_FLATMM_UK_2B 1
479 [s_res_b4]"s"(res_b[4]),
480 [s_res_b5]"s"(res_b[5]),
481 [s_res_b6]"s"(res_b[6]),
482 [s_res_b7]"s"(res_b[7])
483 : _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
484 );
485 // clang-format on
486#pragma clang diagnostic pop
487
488 // return local scratch
490 for(auto i = 0; i < 16; i++)
491 {
492 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
493 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
494 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
495 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
496 }
497 for(auto i = 0; i < 16; i++)
498 {
499 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
500 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
501 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
502 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
503 }
504 return c;
505 }
506 else
507 {
508 // this is the acc thread buffer
509 fp32x4_t v_acc[16]{.0f};
510
511 // B nr->kr
512#pragma clang diagnostic push
513#pragma clang diagnostic ignored "-Winline-asm"
514 // clang-format off
515 asm volatile(
516#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
521 );
522 // clang-format on
523#pragma clang diagnostic pop
524
525 // return local scratch
526 auto c = MakeCBlockTile();
527 for(auto i = 0; i < 16; i++)
528 {
529 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
530 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
531 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
532 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
533 }
534 return c;
535 }
536 }
537};
538
540{
543
544 // TODO: need paired with tile_window_linear!
545 // TODO: need call init_raw() before call this function!
546 template <typename ARes, typename ACoords, typename BRes, typename BCoords, bool Is2B = false>
547 CK_TILE_DEVICE auto
548 operator()(const ARes& res_a,
549 const ACoords& cached_coords_a,
550 const BRes& res_b,
551 const BCoords& cached_coords_b,
552 CK_TILE_LDS_ADDR void* smem,
553 index_t k,
554 index_t tile_offset_a, // for each tile, the offset to move for each unroll
555 index_t tile_offset_b, // for each tile, the offset to move for each unroll
557 {
558 static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
559 static_assert(BCoords::size() == Repeat_N);
560
561 auto a_sst = make_tile_window(
563 reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
564 MakeLdsStoreDesc_A().get_lengths(),
565 {0, 0, 0});
566
567 auto a_sld = [&]() {
568 constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
569 constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
570 sequence<WarpPerBlock_N>,
571 tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
572 tuple<sequence<1, 0>>,
573 tuple<sequence<1, 0>>,
574 sequence<1, 2>,
575 sequence<0, 0>>{};
576 constexpr auto a_block_dstr_encode =
577 detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
580 reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
581 MakeLdsLoadDesc_A().get_lengths(),
582 {0, 0},
583 make_static_tile_distribution(a_block_dstr_encode));
584 }();
585
586 const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
587 const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
588
589 const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
590 constexpr auto smem_buf_size =
591 MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
592 static_assert(a_sld.get_num_of_access() == 8);
593 constexpr auto sld_os = generate_tuple(
594 [&](auto i_access) {
595 return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
596 },
597 number<a_sld.get_num_of_access()>{});
598
599 index_t loop_cnt = k / Block_K;
600
601 if constexpr(Is2B)
602 {
603 // this is the acc thread buffer
604 fp32x4_t v_acc[32]{.0f};
605
606 // B nr->kr
607#pragma clang diagnostic push
608#pragma clang diagnostic ignored "-Winline-asm"
609 // clang-format off
610 asm volatile(
611#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
612#define CK_TILE_FLATMM_UK_2B 1
616 [s_res_b4]"s"(res_b[4]),
617 [s_res_b5]"s"(res_b[5]),
618 [s_res_b6]"s"(res_b[6]),
619 [s_res_b7]"s"(res_b[7])
620 : _EXPAND_ASM_ARGS_CLOBBER, "s24", "s25", "s26", "s27"
621 );
622 // clang-format on
623#pragma clang diagnostic pop
624
625 // return local scratch
627 for(auto i = 0; i < 16; i++)
628 {
629 c.at(number<0>{}).get_thread_buffer()[4 * i + 0] = v_acc[i].x;
630 c.at(number<0>{}).get_thread_buffer()[4 * i + 1] = v_acc[i].y;
631 c.at(number<0>{}).get_thread_buffer()[4 * i + 2] = v_acc[i].z;
632 c.at(number<0>{}).get_thread_buffer()[4 * i + 3] = v_acc[i].w;
633 }
634 for(auto i = 0; i < 16; i++)
635 {
636 c.at(number<1>{}).get_thread_buffer()[4 * i + 0] = v_acc[16 + i].x;
637 c.at(number<1>{}).get_thread_buffer()[4 * i + 1] = v_acc[16 + i].y;
638 c.at(number<1>{}).get_thread_buffer()[4 * i + 2] = v_acc[16 + i].z;
639 c.at(number<1>{}).get_thread_buffer()[4 * i + 3] = v_acc[16 + i].w;
640 }
641 return c;
642 }
643 else
644 {
645 // this is the acc thread buffer
646 fp32x4_t v_acc[16]{.0f};
647
648 // B nr->kr
649#pragma clang diagnostic push
650#pragma clang diagnostic ignored "-Winline-asm"
651 // clang-format off
652 asm volatile(
653#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
658 );
659 // clang-format on
660#pragma clang diagnostic pop
661
662 // return local scratch
663 auto c = MakeCBlockTile();
664 for(auto i = 0; i < 16; i++)
665 {
666 c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
667 c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
668 c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
669 c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
670 }
671 return c;
672 }
673 }
674};
675#undef _EXPAND_ASM_ARGS_OUT_ONE_ACC
676#undef _EXPAND_ASM_ARGS_OUT_TWO_ACC
677#undef _EXPAND_ASM_ARGS_IN
678#undef _EXPAND_ASM_ARGS_CLOBBER
679} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
#define _EXPAND_ASM_ARGS_OUT_TWO_ACC
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:279
#define _EXPAND_ASM_ARGS_CLOBBER
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:357
#define _EXPAND_ASM_ARGS_IN
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:315
#define _EXPAND_ASM_ARGS_OUT_ONE_ACC
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:259
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_ &&lds_tile)
Definition tile_window_utils.hpp:31
float fp32x4_t
Definition vector_type.hpp:128
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M16N16K16< WGAttrCtlEnum::Default_ >, 2, AttrNumAccess > > WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
Definition warp_gemm.hpp:106
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
STL namespace.
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:401
bf16_t ADataType
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:402
bf16_t BDataType
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:403
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:411
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:38
static constexpr index_t NumWarps
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:47
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:154
static constexpr index_t WarpPerBlock_N
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:44
static constexpr index_t SubKPacks
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:55
static constexpr index_t WarpPerBlock_K
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:45
static constexpr index_t Block_K
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:41
static CK_TILE_HOST_DEVICE constexpr auto MakeLdsLoadDesc_A()
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:197
static constexpr index_t Block_N
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:40
static constexpr index_t Block_W
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:58
static constexpr auto GetGemm_AWarpEnc()
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:234
static constexpr index_t Warp_M
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:49
static constexpr index_t Block_M
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:39
static constexpr index_t Block_Kr
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:60
static constexpr index_t Repeat_K
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:64
static constexpr index_t Warp_N
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:50
static constexpr index_t Block_Nr
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:59
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:251
static constexpr index_t Warp_K
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:51
static constexpr index_t Repeat_M
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:62
static constexpr index_t BlockSize
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:53
static constexpr index_t WarpPerBlock_M
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:43
static CK_TILE_HOST_DEVICE constexpr auto MakeLdsStoreDesc_A()
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:180
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:172
static constexpr index_t Repeat_N
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:63
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:540
CK_TILE_DEVICE auto operator()(const ARes &res_a, const ACoords &cached_coords_a, const BRes &res_b, const BCoords &cached_coords_b, CK_TILE_LDS_ADDR void *smem, index_t k, index_t tile_offset_a, index_t tile_offset_b, bool_constant< Is2B >={})
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:548
fp16_t ADataType
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:541
fp16_t BDataType
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:542
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192