blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::A_K1;
122 using Base::B_K1;
123 using Base::I0;
124 using Base::I1;
125 using Base::I2;
126 using Base::KGroup;
127 using Base::KRepeat;
128 using Base::xdlops_gemm;
129 using typename Base::HotLoopInstList;
130
143
144 using Base::AMmaKStride;
145 using Base::BMmaKStride;
146
147 using Base::MWaves;
148 using Base::WaveSize;
149
// Number of K-loop stages fetched from global memory ahead of compute
// (header comment: GlobalPrefetchStages: 2). BlockHasHotloop() compares
// num_loop against this value.
150 static constexpr index_t PrefetchStages = 2;
// Local (LDS) prefill stage count (header comment: LocalPreFillStages: 1).
// NOTE(review): not referenced anywhere in the visible listing.
151 static constexpr index_t PrefillStages = 1;
// NOTE(review): not referenced anywhere in the visible listing; presumably the
// number of global-read staging buffers -- confirm against base class/callers.
152 static constexpr index_t GlobalBufferNum = 1;
// Slot-parity correction for the A thread (register) buffer. Run() only
// addresses two A slots (indices are taken % 2): it reads row m0 from slot
// (m0 + HotloopLocalBufSwitch * mfma_reg_buf) % 2 and prefetches row
// (m0 + 2) % MRepeat into slot (m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % 2.
// With an odd MRepeat (value 1), the rows prefetched for the next LoopFunc land
// in the opposite slot parity, so the LoopFunc with mfma_reg_buf == 1 shifts
// its slot index by one. With an even MRepeat no shift is needed (value 0).
153 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
154
155 template <typename TileDesc_M0_M1_M2_K>
156 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
157 {
158 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
159 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
160 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
161 constexpr index_t K2 = KPack / KGroup;
162 constexpr index_t K1 = WaveSize / NPerXDL;
163 constexpr index_t K0 = KRepeat * KGroup;
164
166 TileDesc_M0_M1_M2_K{},
174 }
175
176 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
178
179 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
180 {
181 return num_loop > PrefetchStages;
182 }
183
184 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
185 {
186 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
187 }
188
189 __device__ static constexpr auto HotLoopScheduler()
190 {
191 // A/B split schedule
192 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
193 constexpr auto num_ds_read_inst_a =
194 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
197
198 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
199
200 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
201 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
202
203 static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
204
205 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
206 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
207
208 constexpr auto ds_read_a_issue_cycle =
209 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
210 constexpr auto ds_read_a_mfma_rate =
211 math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
212
213 // constexpr auto num_dsread_a_mfma =
214 // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
215
216 constexpr auto num_total_stages = MRepeat;
217
218 // Group num_mfma_perstage num_ds_read_a_perstage
219 // since we want to reuse a local register buffer
220 constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
221 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
222
223 constexpr auto num_ds_read_a_mfma_perstage =
224 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
225
226 constexpr auto num_ds_read_a_prefetch_stages = 2;
227
228 constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
229 (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
230 constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
231 (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
232
233 constexpr auto buffer_load_stages_more =
234 (num_buffer_load_inst_a + num_buffer_load_inst_b) -
235 math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
236 (num_total_stages - 2)) *
237 ((num_total_stages - 2));
238
239 constexpr auto buffer_load_b_stages =
240 buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
241 ? num_buffer_load_inst_b / buffer_load_perstage_more
242 : (buffer_load_stages_more +
243 (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
244 buffer_load_perstage_less);
245
246 constexpr auto buffer_load_a_stages =
247 num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
248
249 constexpr auto buffer_load_issue_point_b = 0;
250 constexpr auto buffer_load_issue_point_interval_more =
251 num_mfma_perstage / buffer_load_perstage_more;
252 constexpr auto buffer_load_issue_point_interval_less =
253 num_mfma_perstage / buffer_load_perstage_less;
254 constexpr auto ds_write_issue_point = 0;
255 constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
256
257 // B global read
259 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
260 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
261
262 if constexpr(((i < buffer_load_stages_more) &&
263 (imfma % buffer_load_issue_point_interval_more ==
264 buffer_load_issue_point_b)) ||
265 ((i >= buffer_load_stages_more) &&
266 (imfma % buffer_load_issue_point_interval_less ==
267 buffer_load_issue_point_b)))
268 {
269 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
270 }
271
272 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
273 {
274 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
275 }
276 });
277 });
278
279 // A global read + A local write
281 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
282 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
283 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
284 (imfma % buffer_load_issue_point_interval_more ==
285 ds_write_issue_point)) ||
286 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
287 (imfma % buffer_load_issue_point_interval_less ==
288 ds_write_issue_point)))
289 {
290 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
291 }
292 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
293 (imfma % buffer_load_issue_point_interval_more ==
294 buffer_load_issue_point_a)) ||
295 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
296 (imfma % buffer_load_issue_point_interval_less ==
297 buffer_load_issue_point_a)))
298 {
299 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
300 }
301 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
302 {
303 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
304 }
305 });
306 });
307
308 // lds synchronization, prefetch next loop local A
310 ignore = i;
311 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
312 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
313 if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
314 {
315 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
316 }
317 });
318 });
319 }
320
321 template <typename Stage>
322 __device__ static constexpr auto EpilogueScheduler_1(Stage stage)
323 {
324 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
325 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
326 constexpr auto num_buffer_load_inst_b =
328
329 constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
330
331 constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
332 constexpr auto staged_num_mfma = num_mfma / MRepeat;
333
334 constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
335
336 if constexpr(stage.value == 0)
337 {
338 constexpr auto staged_num_buffer_load_b_per_ds_read_a =
339 num_buffer_load_inst_b / staged_num_ds_read_inst_a;
340 constexpr auto staged_num_mfma_per_buffer_load_b =
341 staged_num_mfma / num_buffer_load_inst_b;
342 // B global
344 ignore = i_inst;
345
347 ignore = ibuf_inst;
348 __builtin_amdgcn_sched_group_barrier(
349 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
350 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
351 });
352
353 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
354 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
355 __builtin_amdgcn_sched_group_barrier(
356 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
357 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
358 });
359
360 __builtin_amdgcn_sched_barrier(0);
361 }
362 else if constexpr(stage.value == 1)
363 {
364 constexpr auto staged_num_mfma_per_ds_write_a =
365 math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
366
367 constexpr auto stage_more_mfma =
368 staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
369
370 // A local write
371 static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
372 if constexpr(i_inst.value < stage_more_mfma)
373 {
374 if(i_inst.value < staged_num_ds_read_inst_a)
375 {
376 __builtin_amdgcn_sched_group_barrier(
377 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
378 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
379 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
380 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
381 }
382 else
383 {
384 __builtin_amdgcn_sched_group_barrier(
385 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
386 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
387 }
388 }
389 else
390 {
391 if(i_inst.value < staged_num_ds_read_inst_a)
392 {
393 __builtin_amdgcn_sched_group_barrier(
394 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
395 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
396 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
397 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
398 }
399 else
400 {
401 __builtin_amdgcn_sched_group_barrier(
402 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
403 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
404 }
405 }
406 });
407 __builtin_amdgcn_sched_barrier(0);
408 }
409 else
410 {
411 // A local Read
413 ignore = i_inst;
414 __builtin_amdgcn_sched_group_barrier(
415 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
416 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
417 });
418
419 __builtin_amdgcn_sched_barrier(0);
420 }
421 }
422
423 __device__ static constexpr auto EpilogueScheduler_2()
424 {
425 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
426
427 constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
428
429 constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
430 constexpr auto staged_num_mfma = num_mfma / MRepeat;
431
432 constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
433
434 // A local Read
436 ignore = i_inst;
437 __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
438 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
439 });
440
441 __builtin_amdgcn_sched_barrier(0);
442 }
443
444 template <bool HasMainLoop,
445 TailNumber TailNum,
446 typename AGridDesc,
447 typename ABlockDesc,
448 typename ABlockTransfer,
449 typename AGridBuffer,
450 typename ABlockBuffer,
451 typename ABlockTransferStep,
452 typename BGridDesc,
453 typename BBlockTransfer,
454 typename BGridBuffer,
455 typename BBlockBuffer,
456 typename BBlockTransferStep,
457 typename CThreadBuffer>
458 __device__ void Run(const AGridDesc& a_grid_desc,
459 const ABlockDesc& a_block_desc,
460 ABlockTransfer& a_blockwise_copy,
461 const AGridBuffer& a_grid_buf,
462 ABlockBuffer& a_block_buf,
463 const ABlockTransferStep& a_block_copy_step,
464 const BGridDesc& b_grid_desc,
465 BBlockTransfer& b_blockwise_copy,
466 BBlockTransfer& b_blockwise_copy_up,
467 const BGridBuffer& b_grid_buf,
468 const BGridBuffer& b_grid_buf_up,
469 BBlockBuffer& b_block_buf,
470 const BBlockTransferStep& b_block_copy_step,
471 CThreadBuffer& c_thread_buf,
472 CThreadBuffer& c_thread_buf_up,
473 index_t num_loop) const
474 {
475 ignore = b_block_buf;
476 __builtin_amdgcn_sched_barrier(0);
478 a_thread_desc_.GetElementSpaceSize());
480 b_thread_desc_.GetElementSpaceSize());
481
482 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
483 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
484 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
485
486 // Global prefetch A1 B1
487 b_blockwise_copy.Run(b_grid_desc,
488 b_grid_buf,
490 b_block_origin_idx,
491 b_thread_bufs(I0));
492
493 b_blockwise_copy_up.Run(b_grid_desc,
494 b_grid_buf_up,
496 b_block_origin_idx,
497 b_thread_bufs_up(I0));
498 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
499 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
500
501 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
502 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
503 __builtin_amdgcn_sched_barrier(0);
504
505 // // Local prefill A1
506 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
507
508 // // Global prefetch A2
509 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
510 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
511
512 // Local prefetch A1
514 static_for<0, 2, 1>{}([&](auto m0) {
515 static_for<0, KRepeat, 1>{}([&](auto k0) {
516 static_for<0, KGroup, 1>{}([&](auto kg0) {
519 a_block_buf.At(I0),
521 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
522 a_thread_buf);
523 });
524 });
525 });
526
527 // Initialize C
528 c_thread_buf.Clear();
529 c_thread_buf_up.Clear();
530
531 __builtin_amdgcn_sched_barrier(0);
532
533 // main body
534 if constexpr(HasMainLoop)
535 {
536 index_t i = 0;
537 do
538 {
539 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
540 b_blockwise_copy.Run(b_grid_desc,
541 b_grid_buf,
543 b_block_origin_idx,
544 b_thread_bufs(local_read_buf));
545 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
546 b_blockwise_copy_up.Run(b_grid_desc,
547 b_grid_buf_up,
549 b_block_origin_idx,
550 b_thread_bufs_up(local_read_buf));
551 b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
552
553 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
554 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
555 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
556 static_for<0, MRepeat, 1>{}([&](auto m0) {
557 static_for<0, KRepeat, 1>{}([&](auto k0) {
558 static_for<0, NRepeat, 1>{}([&](auto n0) {
562
563 static_for<0, KPack, 1>{}([&](auto ik) {
564 a_thread_vec.template AsType<ComputeDataType>()(ik) =
565 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
566 make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
567 2,
568 I0,
569 I0,
570 k0,
571 I0,
572 ik))>{}];
573 b_thread_vec.template AsType<ComputeDataType>()(ik) =
574 b_thread_bufs[mfma_reg_buf]
575 [Number<b_thread_desc_.CalculateOffset(
576 make_tuple(n0, I0, k0, ik))>{}];
577
578 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
579 b_thread_bufs_up[mfma_reg_buf]
580 [Number<b_thread_desc_.CalculateOffset(
581 make_tuple(n0, I0, k0, ik))>{}];
582 });
583
584 using mfma_input_type =
585 typename vector_type<ComputeDataType,
586 xdlops_gemm.K1PerXdlops>::type;
587
588 constexpr index_t c_offset =
589 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
590
591 xdlops_gemm.Run(
592 a_thread_vec.template AsType<mfma_input_type>(),
593 b_thread_vec.template AsType<mfma_input_type>(),
594 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
595
596 xdlops_gemm.Run(
597 a_thread_vec.template AsType<mfma_input_type>(),
598 b_thread_vec_up.template AsType<mfma_input_type>(),
599 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
600 });
601 });
602
603 if constexpr(m0.value == MRepeat - 2)
604 {
606
607 static_for<0, KRepeat, 1>{}([&](auto k0) {
608 static_for<0, KGroup, 1>{}([&](auto kg0) {
609 a_thread_copy_.Run(
611 make_tuple(Number<(m0 + 2) % MRepeat>{},
612 I0,
613 I0,
615 I0,
616 I0),
617 a_block_buf.At(local_read_buf),
620 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
621 2>{},
622 I0,
623 I0,
624 k0,
625 I0,
627 a_thread_buf);
628 });
629 });
630 }
631 else if constexpr(m0.value == (MRepeat - 1))
632 {
633 static_for<0, KRepeat, 1>{}([&](auto k0) {
634 static_for<0, KGroup, 1>{}([&](auto kg0) {
635 a_thread_copy_.Run(
637 make_tuple(Number<(m0 + 2) % MRepeat>{},
638 I0,
639 I0,
641 I0,
642 I0),
643 a_block_buf.At(local_read_buf),
646 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
647 2>{},
648 I0,
649 I0,
650 k0,
651 I0,
653 a_thread_buf);
654 });
655 });
656 }
657 else
658 {
659 static_for<0, KRepeat, 1>{}([&](auto k0) {
660 static_for<0, KGroup, 1>{}([&](auto kg0) {
661 a_thread_copy_.Run(
663 make_tuple(Number<(m0 + 2) % MRepeat>{},
664 I0,
665 I0,
667 I0,
668 I0),
669 a_block_buf.At(mfma_reg_buf),
672 Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
673 2>{},
674 I0,
675 I0,
676 k0,
677 I0,
679 a_thread_buf);
680 });
681 });
682 }
683 });
685 };
686
687 LoopFunc(I0, I1);
688 LoopFunc(I1, I0);
689
690 i += 2;
691 } while(i < (num_loop - 2));
692 }
693 // tail
694 if constexpr(TailNum == TailNumber::Even)
695 {
696 b_blockwise_copy.Run(b_grid_desc,
697 b_grid_buf,
699 b_block_origin_idx,
700 b_thread_bufs(I1));
701
702 b_blockwise_copy_up.Run(b_grid_desc,
703 b_grid_buf_up,
705 b_block_origin_idx,
706 b_thread_bufs_up(I1));
707 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
708 static_for<0, MRepeat, 1>{}([&](auto m0) {
709 static_for<0, KRepeat, 1>{}([&](auto k0) {
710 static_for<0, NRepeat, 1>{}([&](auto n0) {
714
715 static_for<0, KPack, 1>{}([&](auto ik) {
716 a_thread_vec.template AsType<ComputeDataType>()(ik) =
717 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
718 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
719 b_thread_vec.template AsType<ComputeDataType>()(ik) =
720 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
721 make_tuple(n0, I0, k0, ik))>{}];
722
723 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
724 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
725 make_tuple(n0, I0, k0, ik))>{}];
726 });
727
728 using mfma_input_type =
729 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
730
731 constexpr index_t c_offset =
732 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
733
734 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
735 b_thread_vec.template AsType<mfma_input_type>(),
736 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
737
738 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
739 b_thread_vec_up.template AsType<mfma_input_type>(),
740 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
741 });
742 });
743 if constexpr(m0.value == (MRepeat - 2))
744 {
746
747 static_for<0, KRepeat, 1>{}([&](auto k0) {
748 static_for<0, KGroup, 1>{}([&](auto kg0) {
749 a_thread_copy_.Run(
751 make_tuple(Number<(m0 + 2) % MRepeat>{},
752 I0,
753 I0,
755 I0,
756 I0),
757 a_block_buf.At(I1),
760 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
761 a_thread_buf);
762 });
763 });
764 }
765 else if constexpr(m0.value == MRepeat - 1)
766 {
767 static_for<0, KRepeat, 1>{}([&](auto k0) {
768 static_for<0, KGroup, 1>{}([&](auto kg0) {
769 a_thread_copy_.Run(
771 make_tuple(Number<(m0 + 2) % MRepeat>{},
772 I0,
773 I0,
775 I0,
776 I0),
777 a_block_buf.At(I1),
780 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
781 a_thread_buf);
782 });
783 });
784 }
785 else
786 {
787 static_for<0, KRepeat, 1>{}([&](auto k0) {
788 static_for<0, KGroup, 1>{}([&](auto kg0) {
789 a_thread_copy_.Run(
791 make_tuple(Number<(m0 + 2) % MRepeat>{},
792 I0,
793 I0,
795 I0,
796 I0),
797 a_block_buf.At(I0),
800 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
801 a_thread_buf);
802 });
803 });
804 }
805 });
806
808
809 static_for<0, MRepeat, 1>{}([&](auto m0) {
810 static_for<0, KRepeat, 1>{}([&](auto k0) {
811 static_for<0, NRepeat, 1>{}([&](auto n0) {
815
816 static_for<0, KPack, 1>{}([&](auto ik) {
817 a_thread_vec.template AsType<ComputeDataType>()(ik) =
818 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
819 (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
820 b_thread_vec.template AsType<ComputeDataType>()(ik) =
821 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
822 make_tuple(n0, I0, k0, ik))>{}];
823 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
824 b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
825 make_tuple(n0, I0, k0, ik))>{}];
826 });
827
828 using mfma_input_type =
829 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
830
831 constexpr index_t c_offset =
832 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
833
834 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
835 b_thread_vec.template AsType<mfma_input_type>(),
836 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
837
838 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
839 b_thread_vec_up.template AsType<mfma_input_type>(),
840 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
841 });
842 });
843
844 if constexpr(m0.value < (MRepeat - 2))
845 {
846 static_for<0, KRepeat, 1>{}([&](auto k0) {
847 static_for<0, KGroup, 1>{}([&](auto kg0) {
848 a_thread_copy_.Run(
852 a_block_buf.At(I1),
854 make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
855 I0,
856 I0,
857 k0,
858 I0,
860 a_thread_buf);
861 });
862 });
863 }
864 });
865
867 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
868 // latency
869 }
870 else if constexpr(TailNum == TailNumber::Odd)
871 {
872 static_for<0, MRepeat, 1>{}([&](auto m0) {
873 static_for<0, KRepeat, 1>{}([&](auto k0) {
874 static_for<0, NRepeat, 1>{}([&](auto n0) {
878
879 static_for<0, KPack, 1>{}([&](auto ik) {
880 a_thread_vec.template AsType<ComputeDataType>()(ik) =
881 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
882 make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
883 b_thread_vec.template AsType<ComputeDataType>()(ik) =
884 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
885 make_tuple(n0, I0, k0, ik))>{}];
886 b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
887 b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
888 make_tuple(n0, I0, k0, ik))>{}];
889 });
890
891 using mfma_input_type =
892 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
893
894 constexpr index_t c_offset =
895 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
896
897 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
898 b_thread_vec.template AsType<mfma_input_type>(),
899 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
900 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
901 b_thread_vec_up.template AsType<mfma_input_type>(),
902 c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
903 });
904 });
905
906 if constexpr(m0.value < (MRepeat - 2))
907 {
908 static_for<0, KRepeat, 1>{}([&](auto k0) {
909 static_for<0, KGroup, 1>{}([&](auto kg0) {
910 a_thread_copy_.Run(
914 a_block_buf.At(I0),
917 Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
918 a_thread_buf);
919 });
920 });
921 }
922 });
923 }
924 }
925
926 protected:
927 // MRepeat MWave MLane KRepeat KLane KPack
928 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
929 // Reduce the vgpr usage here.
932
934 ComputeDataType,
936 decltype(a_thread_desc_),
937 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
939 5,
940 A_K1,
941 A_K1>;
942
944
947
948 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
949
951};
952
953} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
static constexpr index_t MWaves
Definition blockwise_gemm_pipeline_xdlops_base.hpp:44
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
static constexpr index_t KGroup
Definition blockwise_gemm_pipeline_xdlops_base.hpp:67
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr auto I2
Definition blockwise_gemm_pipeline_xdlops_base.hpp:38
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, BBlockTransfer &b_blockwise_copy, BBlockTransfer &b_blockwise_copy_up, const BGridBuffer &b_grid_buf, const BGridBuffer &b_grid_buf_up, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, CThreadBuffer &c_thread_buf_up, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp:458
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp:933
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp:102
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10