blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp Source File

blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp Source File
blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack,
100 true>
101
102{
104 ADataType,
105 BDataType,
106 ComputeDataType,
107 AccDataType,
108 ATileDesc,
109 BTileDesc,
110 AMmaTileDesc,
111 BMmaTileDesc,
112 ABlockTransferSrcScalarPerVector,
113 BBlockTransferSrcScalarPerVector,
114 MPerBlock,
115 NPerBlock,
116 KPerBlock,
117 MPerXDL,
118 NPerXDL,
119 MRepeat,
120 NRepeat,
121 KPack,
122 true>;
123 using Base::A_K1;
124 using Base::B_K1;
125 using Base::I0;
126 using Base::I1;
127 using Base::KRepeat;
128 using Base::xdlops_gemm;
129 using typename Base::HotLoopInstList;
130
140 using Base::GetWaveIdx;
143
146
148
// K stride used for A and B: both are xdlops_gemm.K0PerXdlops * KPack.
149 static constexpr index_t AMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
150 static constexpr index_t BMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
151
// PrefetchStages is the threshold BlockHasHotloop compares num_loop against.
// PrefillStages and GlobalBufferNum are not referenced in the visible part of this file.
152 static constexpr index_t PrefetchStages = 2;
153 static constexpr index_t PrefillStages = 1;
154 static constexpr index_t GlobalBufferNum = 1;
155
156 // Keep each MFMA inside a single scale block (it must not cross a scale-block boundary)
157 __device__ static auto CalculateAThreadOriginDataIndex()
158 {
159 const auto wave_idx = GetWaveIdx();
160
161 const auto waveId_m = wave_idx[I0];
162
163 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
164
165 return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
166 }
167
168 __device__ static auto CalculateBThreadOriginDataIndex()
169 {
170 const auto wave_idx = GetWaveIdx();
171
172 const auto waveId_n = wave_idx[I1];
173
174 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
175
176 return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
177 }
178
179 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
180 {
181 return num_loop > PrefetchStages;
182 }
183
184 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
185 {
186 return num_loop == 1 ? TailNumber::Odd : TailNumber::Full;
187 }
188
// Emits the instruction-scheduling hints for the hot loop.
// Derives instruction counts from HotLoopInstList and then issues
// __builtin_amdgcn_sched_group_barrier calls that interleave DS writes, VMEM reads,
// DS reads and MFMA instructions (group masks as labelled by the inline comments).
// NOTE(review): this listing omits several original lines (its embedded line numbers
// jump, e.g. 194 -> 197, 238 -> 240, 249 -> 251, 262 -> 264, 278 -> 280). The ternary
// tails of num_ds_read_inst_a/b and the loop headers closed by the bare "});" lines are
// therefore not visible here - consult the real header before changing this function.
189 __device__ static constexpr auto HotLoopScheduler()
190 {
191 // A/B split schedule
192 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
193 constexpr auto num_ds_read_inst_a =
194 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
197 constexpr auto num_ds_read_inst_b =
198 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
201
202 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
203 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
204
205 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
206 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
207
208 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
209
210 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
// Read issue cost: 8 when a single read is 16 bytes wide, otherwise 4.
211 constexpr auto ds_read_a_issue_cycle =
212 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
213 constexpr auto ds_read_b_issue_cycle =
214 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// DS reads grouped with one MFMA: ceil((mfma_cycle - 4) / (2 * issue_cycle)).
215 constexpr auto ds_read_a_mfma_rate =
216 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
217 constexpr auto ds_read_b_mfma_rate =
218 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
219
// Number of MFMAs needed to cover all DS reads at that rate (ceil division).
220 constexpr auto num_dsread_a_mfma =
221 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
222 constexpr auto num_dsread_b_mfma =
223 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
224
225 // stage 1
226 // Separate this part?
227 // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
228 // sizeof(ComputeDataTypeBuf) /
229 // sizeof(BDataType)
230 // ? sizeof(ComputeDataTypeBuf) /
231 // sizeof(ADataType) : sizeof(ComputeDataTypeBuf)
232 // / sizeof(BDataType);
// MFMAs left for stage 1 once the ones paired with DS reads (stage 2) are set aside,
// split evenly (integer division) over all A and B buffer loads.
233 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
234 constexpr auto num_mfma_per_issue =
235 num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
236 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
237 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
238
// A part of stage 1: per issue, pair each DS write with one MFMA, then one VMEM read,
// then the remaining MFMAs of this issue.
// NOTE(review): the header of the enclosing loop over `i` is missing from the listing.
240 ignore = i;
241 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
242 ignore = idswrite;
243 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
244 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
245 });
246 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
247 __builtin_amdgcn_sched_group_barrier(
248 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
249 });
// B part of stage 1: same pattern with the B counts.
// NOTE(review): the header of the enclosing loop over `i` is missing from the listing.
251 ignore = i;
252 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
253 ignore = idswrite;
254 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
255 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
256 });
257 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
258 __builtin_amdgcn_sched_group_barrier(
259 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
260 });
261
262 // stage 2
// A reads: a full group of ds_read_a_mfma_rate DS reads while at least that many remain
// after this step, otherwise the remainder; each group is followed by one MFMA.
// NOTE(review): the header of the enclosing loop over `i` is missing from the listing.
264 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
265 ds_read_a_mfma_rate)
266 {
267 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
268 }
269 else
270 {
271 __builtin_amdgcn_sched_group_barrier(0x100,
272 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
273 ds_read_a_mfma_rate,
274 0); // DS read
275 }
276 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
277 });
278
// B reads: same pattern with the B counts.
// NOTE(review): the header of the enclosing loop over `i` is missing from the listing.
280 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
281 ds_read_b_mfma_rate)
282 {
283 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
284 }
285 else
286 {
287 __builtin_amdgcn_sched_group_barrier(0x100,
288 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
289 ds_read_b_mfma_rate,
290 0); // DS read
291 }
292 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
293 });
294 }
295
// Runs the A/B-scaled block GEMM pipeline.
// Visible flow: two global prefetches of A and B (each followed by loads of the A and B
// scales), one write of the first prefetched tile into the block buffers, clearing of
// c_thread_buf, a first local prefetch into the thread buffers, an optional main loop,
// and a tail selected by TailNum. Every compute pass accumulates MFMA results into
// c_thread_buf_per_scale and then adds them, multiplied by the matching
// c_scale_thread_buf entry, into c_thread_buf.
//   HasMainLoop       - compile-time switch for the do/while main body.
//   NumKBlockPerScale - when 1, the A-scale window is advanced with step entry 2 after
//                       each load, otherwise with step entry 1.
//   TailNum           - TailNumber::Full runs two compute passes in the tail,
//                       TailNumber::Odd runs one.
//   num_loop          - the main body repeats while i < num_loop - 2.
// NOTE(review): this listing omits a number of original lines (its embedded line numbers
// jump, e.g. 357 -> 359, 408 -> 412, 504 -> 506), including the thread-buffer
// declarations, several loop headers and the thread-copy calls. The comments below
// describe only what is visible; consult the real header before changing this function.
296 template <bool HasMainLoop,
297 int NumKBlockPerScale,
298 TailNumber TailNum,
299 typename AGridDesc,
300 typename ABlockDesc,
301 typename ABlockTransfer,
302 typename AGridBuffer,
303 typename ABlockBuffer,
304 typename ABlockTransferStep,
305 typename BGridDesc,
306 typename BBlockDesc,
307 typename BBlockTransfer,
308 typename BGridBuffer,
309 typename BBlockBuffer,
310 typename BBlockTransferStep,
311 typename CScaleThreadDesc,
312 typename CThreadBuffer,
313 typename AScaleGridBuffer,
314 typename AScaleGridDesc,
315 typename AScaleThreadDesc,
316 typename AScaleThreadTransfer,
317 typename AScaleThreadTransferStep,
318 typename BScaleGridBuffer,
319 typename BScaleGridDesc,
320 typename BScaleThreadDesc,
321 typename BScaleThreadTransfer,
322 typename BScaleThreadTransferStep>
323 __device__ void Run(
324 // ABlockCopy
325 const AGridDesc& a_grid_desc,
326 const ABlockDesc& a_block_desc,
327 ABlockTransfer& a_blockwise_copy,
328 const AGridBuffer& a_grid_buf,
329 ABlockBuffer& a_block_buf,
330 const ABlockTransferStep& a_block_copy_step,
331 // BBlockCopy
332 const BGridDesc& b_grid_desc,
333 const BBlockDesc& b_block_desc,
334 BBlockTransfer& b_blockwise_copy,
335 const BGridBuffer& b_grid_buf,
336 BBlockBuffer& b_block_buf,
337 const BBlockTransferStep& b_block_copy_step,
338 // CThread
339 const CScaleThreadDesc& c_scale_thread_desc,
340 CThreadBuffer& c_thread_buf,
341 // AScaleThreadCopy
342 const AScaleGridDesc& a_scale_grid_desc,
343 const AScaleThreadDesc& a_scale_thread_desc,
344 AScaleThreadTransfer& a_scale_thread_copy,
345 const AScaleGridBuffer& a_scale_grid_buf,
346 const AScaleThreadTransferStep& a_scale_thread_copy_step,
347 // BScaleThreadCopy
348 const BScaleGridDesc& b_scale_grid_desc,
349 const BScaleThreadDesc& b_scale_thread_desc,
350 BScaleThreadTransfer& b_scale_thread_copy,
351 const BScaleGridBuffer& b_scale_grid_buf,
352 const BScaleThreadTransferStep& b_scale_thread_copy_step,
353 // num_loop
354 index_t num_loop) const
355 {
356 __builtin_amdgcn_sched_barrier(0);
357 // assume kperblock = scaleblockk
// NOTE(review): the five lines below are the tails of buffer declarations whose first
// lines are missing from the listing (presumably a_thread_buf, b_thread_buf,
// a_scale_thread_buf, b_scale_thread_buf and c_scale_thread_buf - confirm).
359 a_thread_desc_.GetElementSpaceSize());
361 b_thread_desc_.GetElementSpaceSize());
363 a_scale_thread_desc.GetElementSpaceSize());
365 b_scale_thread_desc.GetElementSpaceSize());
367 c_scale_thread_desc.GetElementSpaceSize());
368
369 // Global prefetch 1
370 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
371 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
372
373 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
374 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
375
// Load the A scales for each m0, advancing the source window with step entry 0 after
// every load.
376 static_for<0, MRepeat, 1>{}([&](auto m0) {
377 a_scale_thread_copy.Run(a_scale_grid_desc,
378 a_scale_grid_buf,
379 a_scale_thread_desc,
380 make_tuple(m0, I0),
381 a_scale_thread_buf);
382 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
383 a_scale_thread_copy_step.At(Number<0>{}));
384 });
385
// Advance the A-scale window once more: step entry 2 when NumKBlockPerScale == 1,
// otherwise step entry 1.
386 if constexpr(NumKBlockPerScale == 1)
387 {
388 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
389 a_scale_thread_copy_step.At(Number<2>{}));
390 }
391 else
392 {
393 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
394 a_scale_thread_copy_step.At(Number<1>{}));
395 }
396
// Load the B scales and advance their source window.
397 b_scale_thread_copy.Run(b_scale_grid_desc,
398 b_scale_grid_buf,
399 b_scale_thread_desc,
400 make_tuple(I0, I0),
401 b_scale_thread_buf);
402
403 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
404
405 constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
406 constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
407 constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
408
// Combined scale: c_scale(k0, m0, n0) = a_scale(m0, k0) * b_scale(n0, k0).
// NOTE(review): the three enclosing loop headers are missing from the listing.
412 constexpr index_t c_offset =
413 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
414 constexpr index_t a_offset =
415 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
416 constexpr index_t b_offset =
417 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
418
419 c_scale_thread_buf(Number<c_offset>{}) =
420 a_scale_thread_buf[Number<a_offset>{}] *
421 b_scale_thread_buf[Number<b_offset>{}];
422 });
423 });
424 });
425
426 // Local prefill 1
427 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
428 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
429
430 // Global prefetch 2
431 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
432 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
433
434 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
435 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
436
// Second A/B scale load, same sequence as after global prefetch 1.
437 static_for<0, MRepeat, 1>{}([&](auto m0) {
438 a_scale_thread_copy.Run(a_scale_grid_desc,
439 a_scale_grid_buf,
440 a_scale_thread_desc,
441 make_tuple(m0, I0),
442 a_scale_thread_buf);
443 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
444 a_scale_thread_copy_step.At(Number<0>{}));
445 });
446
447 if constexpr(NumKBlockPerScale == 1)
448 {
449 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
450 a_scale_thread_copy_step.At(Number<2>{}));
451 }
452 else
453 {
454 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
455 a_scale_thread_copy_step.At(Number<1>{}));
456 }
457
458 b_scale_thread_copy.Run(b_scale_grid_desc,
459 b_scale_grid_buf,
460 b_scale_thread_desc,
461 make_tuple(I0, I0),
462 b_scale_thread_buf);
463
464 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
465
466 // Initialize C
467 c_thread_buf.Clear();
468
// Per-scale accumulator sized by xdlops_gemm.GetRegSizePerXdlops().
// NOTE(review): the first line of this declaration is missing from the listing.
470 AccDataType,
471 1,
472 xdlops_gemm.GetRegSizePerXdlops(),
473 true>
474 c_thread_buf_per_scale;
475
476 // Local prefetch 1
// NOTE(review): the calls that copy from a_block_buf / b_block_buf into the thread
// buffers are only partly visible (their first lines are missing from the listing).
478 static_for<0, KRepeat, 1>{}([&](auto k0) {
479 static_for<0, MRepeat, 1>{}([&](auto m0) {
482 a_block_buf,
484 make_tuple(m0, I0, k0, I0),
485 a_thread_buf);
486 });
487 static_for<0, NRepeat, 1>{}([&](auto n0) {
490 b_block_buf,
492 make_tuple(n0, I0, k0, I0),
493 b_thread_buf);
494 });
495 });
496
497 __builtin_amdgcn_sched_barrier(0);
498
499 // main body
500 if constexpr(HasMainLoop)
501 {
502 index_t i = 0;
503 do
504 {
// Write the tile fetched earlier into the block buffers, then fetch the next global
// tile and advance the source windows.
506 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
507 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
508
509 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
510 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
511
512 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
513 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
514
// Compute pass: for every (m0, n0, kscale0) zero the per-scale accumulator, run the
// MFMAs over KRepeat / num_scale_k_block K steps, then add the scaled result to C.
515 static_for<0, MRepeat, 1>{}([&](auto m0) {
516 static_for<0, NRepeat, 1>{}([&](auto n0) {
517 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
518 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
519 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
520 .template AsType<AccDataType>()(Number<t>{}) = 0;
521 });
522 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
525
// Gather KPack elements of A and B from the thread buffers into the MFMA operands.
526 static_for<0, KPack, 1>{}([&](auto ik) {
527 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
528 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
529 make_tuple(m0,
530 I0,
531 kscale0 * KRepeat / num_scale_k_block + k0,
532 ik))>{}];
533 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
534 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
535 make_tuple(n0,
536 I0,
537 kscale0 * KRepeat / num_scale_k_block + k0,
538 ik))>{}];
539 });
540
541 using mfma_input_type =
543 xdlops_gemm.K1PerXdlops>::type;
544
545 xdlops_gemm.template Run<>(
546 a_thread_vec.template AsType<mfma_input_type>(),
547 b_thread_vec.template AsType<mfma_input_type>(),
548 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
549 });
// c_thread_buf += per-scale accumulator * c_scale(kscale0, m0, n0 * num_scale_n_block / NRepeat).
550 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
551 constexpr index_t c_offset =
552 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
553 constexpr index_t cscale_offset =
554 CScaleThreadDesc{}.CalculateOffset(
555 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
556
557 c_thread_buf(Number<c_offset>{}) +=
558 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
559 .template AsType<AccDataType>()[Number<t>{}] *
561 c_scale_thread_buf[Number<cscale_offset>{}]);
562 });
563 });
564 });
565 });
566
// Recompute c_scale = a_scale * b_scale from the current contents of the scale thread
// buffers. NOTE(review): two inner loop headers are missing from the listing.
567 static_for<0, MRepeat, 1>{}([&](auto m0) {
570 constexpr index_t c_offset =
571 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
572 constexpr index_t a_offset =
573 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
574 constexpr index_t b_offset =
575 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
576
577 c_scale_thread_buf(Number<c_offset>{}) =
578 a_scale_thread_buf[Number<a_offset>{}] *
579 b_scale_thread_buf[Number<b_offset>{}];
580 });
581 });
582 });
583
// Local prefetch of the next tile into the thread buffers (calls only partly visible).
585 static_for<0, KRepeat, 1>{}([&](auto k) {
586 static_for<0, MRepeat, 1>{}([&](auto m0) {
589 a_block_buf,
591 make_tuple(m0, I0, k, I0),
592 a_thread_buf);
593 });
594 static_for<0, NRepeat, 1>{}([&](auto n0) {
597 b_block_buf,
599 make_tuple(n0, I0, k, I0),
600 b_thread_buf);
601 });
602 });
603
// Load the next A/B scales and advance their source windows.
604 static_for<0, MRepeat, 1>{}([&](auto m0) {
605 a_scale_thread_copy.Run(a_scale_grid_desc,
606 a_scale_grid_buf,
607 a_scale_thread_desc,
608 make_tuple(m0, I0),
609 a_scale_thread_buf);
610 a_scale_thread_copy.MoveSrcSliceWindow(
611 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
612 });
613
614 if constexpr(NumKBlockPerScale == 1)
615 {
616 a_scale_thread_copy.MoveSrcSliceWindow(
617 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
618 }
619 else
620 {
621 a_scale_thread_copy.MoveSrcSliceWindow(
622 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
623 }
624
625 b_scale_thread_copy.Run(b_scale_grid_desc,
626 b_scale_grid_buf,
627 b_scale_thread_desc,
628 make_tuple(I0, I0),
629 b_scale_thread_buf);
630
631 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
633 __builtin_amdgcn_sched_barrier(0);
634 i += 1;
635 } while(i < (num_loop - 2));
636 }
637
638 // tail
639 if constexpr(TailNum == TailNumber::Full)
640 {
// Full tail: write the last fetched tile to the block buffers, compute on the tile
// already held in the thread buffers, refresh c_scale, prefetch the last tile into the
// thread buffers and compute once more.
642 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
643 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
644
645 static_for<0, MRepeat, 1>{}([&](auto m0) {
646 static_for<0, NRepeat, 1>{}([&](auto n0) {
647 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
648 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
649 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
650 .template AsType<AccDataType>()(Number<t>{}) = 0;
651 });
652 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
655
656 static_for<0, KPack, 1>{}([&](auto ik) {
657 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
658 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
659 make_tuple(m0,
660 I0,
661 kscale0 * KRepeat / num_scale_k_block + k0,
662 ik))>{}];
663 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
664 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
665 make_tuple(n0,
666 I0,
667 kscale0 * KRepeat / num_scale_k_block + k0,
668 ik))>{}];
669 });
670
671 using mfma_input_type =
673 xdlops_gemm.K1PerXdlops>::type;
674
675 xdlops_gemm.template Run<>(
676 a_thread_vec.template AsType<mfma_input_type>(),
677 b_thread_vec.template AsType<mfma_input_type>(),
678 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
679 });
680 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
681 constexpr index_t c_offset =
682 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
683 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
684 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
685
686 c_thread_buf(Number<c_offset>{}) +=
687 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
688 .template AsType<AccDataType>()[Number<t>{}] *
690 c_scale_thread_buf[Number<cscale_offset>{}]);
691 });
692 });
693 });
694 });
695
// Refresh c_scale for the final compute pass (two inner loop headers missing from the
// listing).
696 static_for<0, MRepeat, 1>{}([&](auto m0) {
699 constexpr index_t c_offset =
700 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
701 constexpr index_t a_offset =
702 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
703 constexpr index_t b_offset =
704 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
705
706 c_scale_thread_buf(Number<c_offset>{}) =
707 a_scale_thread_buf[Number<a_offset>{}] *
708 b_scale_thread_buf[Number<b_offset>{}];
709 });
710 });
711 });
712
// Local prefetch of the last tile into the thread buffers (calls only partly visible).
714 static_for<0, KRepeat, 1>{}([&](auto k) {
715 static_for<0, MRepeat, 1>{}([&](auto m0) {
718 a_block_buf,
720 make_tuple(m0, I0, k, I0),
721 a_thread_buf);
722 });
723 static_for<0, NRepeat, 1>{}([&](auto n0) {
726 b_block_buf,
728 make_tuple(n0, I0, k, I0),
729 b_thread_buf);
730 });
731 });
732
734 __builtin_amdgcn_sched_barrier(0);
735
// Final compute pass of the Full tail.
736 static_for<0, MRepeat, 1>{}([&](auto m0) {
737 static_for<0, NRepeat, 1>{}([&](auto n0) {
738 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
739 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
740 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
741 .template AsType<AccDataType>()(Number<t>{}) = 0;
742 });
743 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
746
747 static_for<0, KPack, 1>{}([&](auto ik) {
748 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
749 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
750 make_tuple(m0,
751 I0,
752 kscale0 * KRepeat / num_scale_k_block + k0,
753 ik))>{}];
754 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
755 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
756 make_tuple(n0,
757 I0,
758 kscale0 * KRepeat / num_scale_k_block + k0,
759 ik))>{}];
760 });
761
762 using mfma_input_type =
764 xdlops_gemm.K1PerXdlops>::type;
765
766 xdlops_gemm.template Run<>(
767 a_thread_vec.template AsType<mfma_input_type>(),
768 b_thread_vec.template AsType<mfma_input_type>(),
769 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
770 });
771 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
772 constexpr index_t c_offset =
773 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
774 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
775 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
776
777 c_thread_buf(Number<c_offset>{}) +=
778 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
779 .template AsType<AccDataType>()[Number<t>{}] *
781 c_scale_thread_buf[Number<cscale_offset>{}]);
782 });
783 });
784 });
785 });
786 __builtin_amdgcn_sched_barrier(0);
787 }
788 else if constexpr(TailNum == TailNumber::Odd)
789 {
// Odd tail: a single compute pass on the tile already held in the thread buffers.
790 static_for<0, MRepeat, 1>{}([&](auto m0) {
791 static_for<0, NRepeat, 1>{}([&](auto n0) {
792 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
793 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
794 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
795 .template AsType<AccDataType>()(Number<t>{}) = 0;
796 });
797 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
800
801 static_for<0, KPack, 1>{}([&](auto ik) {
802 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
803 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
804 make_tuple(m0,
805 I0,
806 kscale0 * KRepeat / num_scale_k_block + k0,
807 ik))>{}];
808 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
809 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
810 make_tuple(n0,
811 I0,
812 kscale0 * KRepeat / num_scale_k_block + k0,
813 ik))>{}];
814 });
815
816 using mfma_input_type =
818 xdlops_gemm.K1PerXdlops>::type;
819
820 xdlops_gemm.template Run<>(
821 a_thread_vec.template AsType<mfma_input_type>(),
822 b_thread_vec.template AsType<mfma_input_type>(),
823 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
824 });
825 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
826 constexpr index_t c_offset =
827 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
828 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
829 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
830
831 c_thread_buf(Number<c_offset>{}) +=
832 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
833 .template AsType<AccDataType>()[Number<t>{}] *
835 c_scale_thread_buf[Number<cscale_offset>{}]);
836 });
837 });
838 });
839 });
840 __builtin_amdgcn_sched_barrier(0);
841 }
842 }
843
844 protected:
845 using Base::a_thread_desc_;
846 using Base::b_thread_desc_;
847 using Base::c_thread_desc_;
850 decltype(a_block_desc_m0_m1_m2_k),
851 decltype(a_thread_desc_),
854 3,
855 A_K1,
856 A_K1>;
857
860 decltype(b_block_desc_n0_n1_n2_k),
861 decltype(b_thread_desc_),
864 3,
865 B_K1,
866 B_K1>;
867
870};
871
872} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeDataTypeBuf, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp:858
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp:323
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp:848
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp:103
Definition blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp:37
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10