gridwise_gemm_xdlops_splitk_lds_direct_load.hpp Source File

gridwise_gemm_xdlops_splitk_lds_direct_load.hpp Source File#

Composable Kernel: gridwise_gemm_xdlops_splitk_lds_direct_load.hpp Source File
gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
21
22namespace ck {
23
// Device entry point for the split-K XDL GEMM with direct loads into LDS.
//
// On gfx9 targets the kernel reserves a statically sized LDS byte array and hands it,
// together with all arguments, to GridwiseGemm::Run. On any other target the arguments are
// only assigned to `ignore` and the kernel does nothing.
//
// Parameters:
//   karg          GridwiseGemm::Argument, taken by value and forwarded to Run.
//   b2c_map       block-to-C-tile map, forwarded to Run.
//                 NOTE(review): this is a const reference in a __global__ signature; whether
//                 that is safe depends on how the launcher passes it - confirm.
//   a_element_op, b_element_op, c_element_op
//                 element-wise operators, forwarded to Run.
//
// NOTE(review): this listing skips a numbered line between the `#if CK_USE_LAUNCH_BOUNDS` and
// `#endif` lines below; presumably the launch-bounds attribute belongs there - confirm
// against the real header.
24template <typename GridwiseGemm,
25 bool HasMainKBlockLoop,
26 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
27 typename Block2CTileMap,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CElementwiseOperation>
31__global__ void
32#if CK_USE_LAUNCH_BOUNDS
34#endif
35 kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg,
36 const Block2CTileMap& b2c_map,
37 const AElementwiseOperation a_element_op,
38 const BElementwiseOperation b_element_op,
39 const CElementwiseOperation c_element_op)
40{
// The real body is compiled only when __gfx9__ is defined.
41#if defined(__gfx9__)
// Compile-time filter: the body is discarded when GridwiseGemm reports the memory operation
// as not valid for compilation.
42 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
43 {
// LDS size is a compile-time constant supplied by GridwiseGemm.
44 constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
45
46 __shared__ uint8_t p_shared[shared_size];
47
// The LDS array is passed as an untyped pointer to Run.
48 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
49 karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
50 }
51#else
// Other targets: mark every parameter as used so the empty kernel compiles without
// unused-parameter warnings.
52 ignore = karg;
53 ignore = b2c_map;
54 ignore = a_element_op;
55 ignore = b_element_op;
56 ignore = c_element_op;
57#endif // end of if (defined(__gfx9__))
58}
59
60template <index_t BlockSize,
61 typename FloatA,
62 typename FloatB,
63 typename FloatAcc,
64 typename FloatC,
65 typename ALayout,
66 typename BLayout,
67 typename CLayout,
68 typename AElementwiseOperation,
69 typename BElementwiseOperation,
70 typename CElementwiseOperation,
72 index_t NumGemmKPrefetchStage,
73 index_t MPerBlock,
74 index_t NPerBlock,
75 index_t K0PerBlock,
76 index_t MPerXdl,
77 index_t NPerXdl,
78 index_t K1Value,
79 index_t MRepeat,
80 index_t NRepeat,
81 typename ABlockTransferThreadClusterLengths_K0_M_K1,
82 typename ABlockTransferSrcAccessOrder,
83 index_t ABlockTransferSrcVectorDim,
84 index_t ABlockTransferSrcScalarPerVector,
85 bool ABlockLdsExtraM,
86 typename BBlockTransferThreadClusterLengths_K0_N_K1,
87 typename BBlockTransferSrcAccessOrder,
88 index_t BBlockTransferSrcVectorDim,
89 index_t BBlockTransferSrcScalarPerVector,
90 bool BBlockLdsExtraN,
91 index_t CShuffleMRepeatPerShuffle,
92 index_t CShuffleNRepeatPerShuffle,
93 index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
94 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
97 typename ComputeType = FloatC>
99{
// Compile-time index constants; used in this struct to pick a dimension or tuple element,
// e.g. GetLength(I0) or block_work_idx[I1].
100 static constexpr auto I0 = Number<0>{};
101 static constexpr auto I1 = Number<1>{};
102 static constexpr auto I2 = Number<2>{};
103 static constexpr auto I3 = Number<3>{};
104 static constexpr auto I4 = Number<4>{};
105 static constexpr auto I5 = Number<5>{};
106 static constexpr auto I6 = Number<6>{};
107 static constexpr auto I7 = Number<7>{};
108
109 // K1 should be Number<...>
110 static constexpr auto K1 = Number<K1Value>{};
// K extent covered by one block tile: K1Value * K0PerBlock, as a compile-time Number.
111 static constexpr auto KPerBlock = Number<K1Value * K0PerBlock>{};
// NOTE(review): M01 / N01 are not referenced by any line visible in this listing; they may
// be used on lines that were dropped - confirm before removing.
112 static constexpr auto M01 = 1;
113 static constexpr auto N01 = 1;
114
115 static constexpr auto gemm_padder =
117 MPerBlock, NPerBlock, K1* K0PerBlock};
118
120
123
125 {
126 const FloatA* p_a_grid;
127 const FloatB* p_b_grid;
128 FloatC* p_c_grid;
140
141 Argument(const FloatA* p_a_grid_,
142 const FloatB* p_b_grid_,
143 FloatC* p_c_grid_,
144 index_t M_,
145 index_t N_,
146 index_t K_,
147 index_t StrideA_,
148 index_t StrideB_,
149 index_t StrideC_,
150 index_t MPadded_,
151 index_t NPadded_,
152 index_t KPadded_,
153 index_t K0Padded_,
154 index_t k_batch_)
155 : p_a_grid(p_a_grid_),
156 p_b_grid(p_b_grid_),
157 p_c_grid(p_c_grid_),
158 M(M_),
159 N(N_),
160 K(K_),
161 StrideA(StrideA_),
162 StrideB(StrideB_),
163 StrideC(StrideC_),
164 MPadded(MPadded_),
165 NPadded(NPadded_),
166 KPadded(KPadded_),
167 K0Padded(K0Padded_),
168 k_batch(k_batch_)
169 {
170 }
171
172 void Print() const
173 {
174 std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
175 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
176 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
177 << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", "
178 << "KB:" << k_batch << "}" << std::endl;
179 }
180 };
181
182 __host__ __device__ static auto CalculateGridSize(const Argument& karg)
183 {
184 return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
185 math::integer_divide_ceil(karg.M, MPerBlock),
186 karg.k_batch);
187 }
188
189 // prefer this to be called on host
190 __host__ __device__ static auto CalculateMPadded(index_t M)
191 {
192 return math::integer_least_multiple(M, MPerBlock);
193 }
194
195 __host__ __device__ static auto CalculateNPadded(index_t N)
196 {
197 return math::integer_least_multiple(N, NPerBlock);
198 }
199
200 __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
201 {
202 // k_batch * k0 * k0_per_block * k1
203 auto K_t = K_Batch * K0PerBlock * K1;
204 return (K + K_t - 1) / K_t * K0PerBlock;
205 }
206
207 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
208 {
209 auto K0Padded = CalculateK0Padded(K, K_Batch);
210 return K_Batch * K0Padded * K1;
211 }
212
// Builds the grid descriptor for the A matrix.
//
// Visible steps:
//   1. An (M, K) naive descriptor is created with strides (StrideA, 1) or (1, StrideA).
//   2. Depending on GemmSpec, the descriptor is transformed; two of the three branches
//      right-pad the M dimension by MPad - M.
//
// NOTE(review): this listing drops several numbered lines in this function: the conditions
// that choose between the two stride layouts, the transforms applied to K, and the return
// statements. How KBatch, K0Padded and KPad are used is therefore not visible here;
// presumably they shape the (KBatch, K0, M, K1) view named in the function - confirm against
// the real header.
213 __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
214 index_t MPad,
215 index_t K,
216 index_t StrideA,
217 index_t KBatch,
218 index_t K0Padded,
219 index_t KPad)
220 {
// Base (M, K) view of A; the two returns differ only in which dimension has unit stride.
221 const auto a_grid_desc_m_k = [&]() {
223 {
224 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
225 }
227 {
228 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
229 }
230 }();
231
// First branch: goes through an intermediate descriptor (a_grid_desc_m_kpad) and also
// right-pads M.
236 {
237
238 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
239 a_grid_desc_m_k,
243
244 // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
246 a_grid_desc_m_kpad,
248 make_right_pad_transform(M, MPad - M)),
251 }
// Second branch: right-pads M directly on the base descriptor.
252 else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
254 {
255 // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
257 a_grid_desc_m_k,
259 make_right_pad_transform(M, MPad - M)),
262 }
// Remaining case: no right-pad transform is visible in this branch.
263 else
264 {
266 a_grid_desc_m_k,
271 }
272 }
273
// Builds the grid descriptor for the B matrix.
//
// Note the parameter order: K comes first, then NPad, then N. This differs from the A
// counterpart, which takes (M, MPad, K, ...).
//
// Visible steps:
//   1. A (K, N) naive descriptor is created with strides (StrideB, 1) or (1, StrideB).
//   2. Depending on GemmSpec, the descriptor is transformed; two of the three branches
//      right-pad the N dimension by NPad - N.
//
// NOTE(review): this listing drops several numbered lines in this function: the conditions
// that choose between the two stride layouts, the transforms applied to K, and the return
// statements. How KBatch, K0Padded and KPad are used is therefore not visible here -
// confirm against the real header.
274 __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
275 index_t NPad,
276 index_t N,
277 index_t StrideB,
278 index_t KBatch,
279 index_t K0Padded,
280 index_t KPad)
281 {
// Base (K, N) view of B; the two returns differ only in which dimension has unit stride.
282 const auto b_grid_desc_k_n = [&]() {
284 {
285 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
286 }
288 {
289 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
290 }
291 }();
292
// First branch: goes through an intermediate descriptor (b_grid_desc_kpad_n) and also
// right-pads N.
297 {
298
299 const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
300 b_grid_desc_k_n,
304
305 // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
307 b_grid_desc_kpad_n,
309 make_right_pad_transform(N, NPad - N)),
312 }
// Second branch: right-pads N directly on the base descriptor.
313 else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
315 {
316 // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
318 b_grid_desc_k_n,
320 make_right_pad_transform(N, NPad - N)),
323 }
// Remaining case: no right-pad transform is visible in this branch.
324 else
325 {
327 b_grid_desc_k_n,
332 }
333 }
334
// Builds the (M, N) grid descriptor for C and returns it after passing it through
// gemm_padder.PadCDescriptor_M_N.
//
// The base descriptor uses strides (StrideC, 1) or (1, StrideC).
// NOTE(review): the conditions selecting between those two layouts sit on lines dropped
// from this listing; presumably they test CLayout - confirm.
335 __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
336 {
337 const auto c_grid_desc_m_n = [&]() {
339 {
340 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
341 }
343 {
344 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
345 }
346 }();
347
348 return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
349 }
350
// Returns the LDS size in bytes that one block needs.
//
// Result = max( NumGemmKPrefetchStage * (A block space + B block space) * sizeof(ComputeType),
//               c_block_size * sizeof(FloatC) ).
// A maximum rather than a sum is consistent with Run, where the A/B buffers and the C staging
// buffer are all created from the same p_shared_block pointer.
//
// NOTE(review): the descriptor construction for the "extra M/N" branches, the call that
// builds the non-extra descriptors, and the right-hand side of c_block_size are on lines
// dropped from this listing.
351 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
352 {
// Alignment unit, in elements, used for both descriptors and sizes below.
353 constexpr auto max_lds_align = K1;
354
355 // A matrix in LDS memory, dst of blockwise copy
356 constexpr auto a_k0_m_k1_block_desc = [&]() {
357 if constexpr(ABlockLdsExtraM)
358 {
362 }
363 else
364 {
366 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
367 }
368 }();
369
370 // B matrix in LDS memory, dst of blockwise copy
371 constexpr auto b_k0_n_k1_block_desc = [&]() {
372 if constexpr(BBlockLdsExtraN)
373 {
377 }
378 else
379 {
381 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
382 }
383 }();
384
385 // LDS allocation for A and B: be careful of alignment
// Element counts, each rounded up to a multiple of max_lds_align.
386 constexpr auto a_block_space_size =
387 math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
388
389 constexpr auto b_block_space_size =
390 math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
391
392 constexpr auto c_block_size =
394
// A/B space is multiplied by the number of prefetch stages; C staging is counted once.
395 return math::max(NumGemmKPrefetchStage * (a_block_space_size + b_block_space_size) *
396 sizeof(ComputeType),
397 c_block_size * sizeof(FloatC));
398 }
399
400 static constexpr index_t MXdlPerWave = MRepeat;
401 static constexpr index_t NXdlPerWave = NRepeat;
403
// Returns true when the problem in `karg` passes every divisibility / support check below,
// false as soon as one fails.
//
// NOTE(review): every `{ ... }` group below is guarded by a condition on a line that was
// dropped from this listing (the numbering skips before each group). Presumably the first
// three guards depend on GemmSpec padding modes and the last three on the A/B/C layouts -
// confirm against the real header.
404 __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
405 {
// M must be a multiple of MPerBlock (when the dropped guard applies).
410 {
411 if(!(karg.M % MPerBlock == 0))
412 {
413 return false;
414 }
415 }
416
// N must be a multiple of NPerBlock (when the dropped guard applies).
421 {
422 if(!(karg.N % NPerBlock == 0))
423 {
424 return false;
425 }
426 }
427
// K must be a multiple of k_batch * K0PerBlock * K1 (when the dropped guard applies).
// NOTE(review): K_t is zero when karg.k_batch is zero, which makes the modulo below
// undefined; no check for that is visible in this listing.
432 {
433
434 auto K_t = karg.k_batch * K0PerBlock * K1;
435 if(!(karg.K % K_t == 0))
436 {
437 return false;
438 }
439 }
440
// A vector width: K or M (depending on the dropped condition) must be a multiple of
// ABlockTransferSrcScalarPerVector.
442 {
443 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
444 {
445 return false;
446 }
447 }
448 else
449 {
450 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
451 {
452 return false;
453 }
454 }
455
// B vector width: N or K (depending on the dropped condition) must be a multiple of
// BBlockTransferSrcScalarPerVector.
457 {
458 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
459 {
460 return false;
461 }
462 }
463 else
464 {
465 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
466 {
467 return false;
468 }
469 }
470
// C vector width: N or M (depending on the dropped condition) must be a multiple of
// CBlockTransferScalarPerVector_NWaveNPerXDL.
472 {
473 if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
474 {
475 return false;
476 }
477 }
478 else
479 {
480 if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
481 {
482 return false;
483 }
484 }
485
// The pipeline must support the number of K iterations, K0Padded / K0PerBlock.
486 const auto num_k_loop = karg.K0Padded / K0PerBlock;
487 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
488 {
489 return false;
490 }
491
492 return true;
493 }
494
495 __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
496 {
497 const index_t K0Padded =
498 math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
499 const index_t KPad = KBatch * K0Padded * K1;
500 return KPad;
501 }
502
503 __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
504 {
505 const index_t num_loop = K0Padded / K0PerBlock;
506 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
507 }
508
// Derives a blocked view of the (M, N) C descriptor.
//
// Reads the M and N lengths from `c_m_n_grid_desc` and computes the block counts
// MBlock = M / MPerBlock and NBlock = N / NPerBlock with integer division, so any remainder
// is discarded (presumably the descriptor is already padded to tile multiples - confirm).
//
// NOTE(review): the transform that produces the returned descriptor is on lines dropped
// from this listing; only its first argument (c_m_n_grid_desc) is visible.
509 template <typename CGridDesc>
510 __host__ __device__ static constexpr auto
511 MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
512 {
513 const auto M = c_m_n_grid_desc.GetLength(I0);
514 const auto N = c_m_n_grid_desc.GetLength(I1);
515
516 const auto MBlock = M / MPerBlock;
517 const auto NBlock = N / NPerBlock;
518
520 c_m_n_grid_desc,
525 }
526
// Builds the descriptor of the C staging tile.
//
// Computes MWave = MPerBlock / (MRepeat * MPerXdl) and NWave = NPerBlock / (NRepeat * NPerXdl).
//
// NOTE(review): the line carrying the function name and the lines building the returned
// descriptor were dropped from this listing; only one `I1` argument of that expression is
// visible.
527 __host__ __device__ static constexpr auto
529 {
530 constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
531 constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
532
536 I1,
538 }
539
540 // return block_id to C matrix tile idx (m0, n0, k_split) mapping
// NOTE(review): the return statement is on a line dropped from this listing, so the concrete
// map type is not visible here.
541 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
542 {
544 }
545
548
549 template <bool HasMainKBlockLoop,
550 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
551 typename Block2CTileMap>
552 __device__ static void Run(const Argument& karg,
553 void* __restrict__ p_shared_block,
554 const Block2CTileMap& block_2_ctile_map,
555 const AElementwiseOperation a_element_op = AElementwiseOperation{},
556 const BElementwiseOperation b_element_op = BElementwiseOperation{},
557 const CElementwiseOperation c_element_op = CElementwiseOperation{})
558 {
559 // Elementwise operations are not supported for A and B, arguments left only for the API
560 // consistency.
561 (void)a_element_op;
562 (void)b_element_op;
563
564 const FloatA* p_a_grid = karg.p_a_grid;
565 const FloatB* p_b_grid = karg.p_b_grid;
566 FloatC* p_c_grid = karg.p_c_grid;
567 const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
568 karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
569 const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
570 karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
571 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
572
573 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
575
576 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
577 p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
578 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
579 p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
581 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
582
583 // divide block work by [KBatch, M, N]
584 const auto block_work_idx =
585 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
586
587 if(!block_2_ctile_map.ValidCTileIndex(
588 block_work_idx,
589 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
590 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
591 {
592 return;
593 }
594
595 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
596 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
597 const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
598
599 // HACK: this force m/n_block_data_idx_on_grid into SGPR
600 const index_t m_block_data_idx_on_grid =
601 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
602
603 const index_t n_block_data_idx_on_grid =
604 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
605
606 // lds max alignment
607 constexpr auto max_lds_align = K1;
608
609 // A matrix in LDS memory, dst of blockwise copy
610 constexpr auto a_k0_m_k1_block_desc = [&]() {
611 if constexpr(ABlockLdsExtraM)
612 {
616 }
617 else
618 {
622 }
623 }();
624
625 constexpr auto a_b_k0_m_k1_block_desc = [&]() {
626 if constexpr(ABlockLdsExtraM)
627 {
632 K1,
633 I1));
634 }
635 else
636 {
641 }
642 }();
643 // B matrix in LDS memory, dst of blockwise copy
644 constexpr auto b_k0_n_k1_block_desc = [&]() {
645 if constexpr(BBlockLdsExtraN)
646 {
650 }
651 else
652 {
656 }
657 }();
658
659 constexpr auto b_b_k0_n_k1_block_desc = [&]() {
660 if constexpr(BBlockLdsExtraN)
661 {
666 K1,
667 I1));
668 }
669 else
670 {
675 }
676 }();
677
678 auto a_blockwise_copy =
679 ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
680 Sequence<1, K0PerBlock, MPerBlock, K1>,
681 ABlockTransferThreadClusterLengths_K0_M_K1,
682 ABlockTransferSrcAccessOrder,
683 FloatA,
684 ComputeType,
685 decltype(a_b_k0_m_k1_grid_desc),
686 decltype(a_b_k0_m_k1_block_desc),
687 ABlockTransferSrcAccessOrder,
688 ABlockTransferSrcVectorDim,
689 3,
690 ABlockTransferSrcScalarPerVector>(
691 a_b_k0_m_k1_grid_desc,
692 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
693 a_b_k0_m_k1_block_desc,
694 make_multi_index(0, 0, 0, 0));
695
696 auto b_blockwise_copy =
697 ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
698 Sequence<1, K0PerBlock, NPerBlock, K1>,
699 BBlockTransferThreadClusterLengths_K0_N_K1,
700 BBlockTransferSrcAccessOrder,
701 FloatB,
702 ComputeType,
703 decltype(b_b_k0_n_k1_grid_desc),
704 decltype(b_b_k0_n_k1_block_desc),
705 BBlockTransferSrcAccessOrder,
706 BBlockTransferSrcVectorDim,
707 3,
708 BBlockTransferSrcScalarPerVector>(
709 b_b_k0_n_k1_grid_desc,
710 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
711 b_b_k0_n_k1_block_desc,
712 make_multi_index(0, 0, 0, 0));
713
715 BlockSize,
716 ComputeType, // ComputeType A
717 ComputeType, // ComputeType B
718 FloatAcc,
719 decltype(a_k0_m_k1_block_desc),
720 decltype(b_k0_n_k1_block_desc),
721 MPerXdl,
722 NPerXdl,
723 MRepeat,
724 NRepeat,
725 K1,
726 LoopSched>();
727
728 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
729
730 constexpr auto a_block_space_size =
731 math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
732
733 constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
734 constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
735
736 const auto a_buffers_offset = 0;
737 auto a_block_buffers =
738 ck::lds_utils::AllocateLdsBuffers<ComputeType, NumGemmKPrefetchStage>(
739 p_shared_block,
740 a_b_k0_m_k1_block_desc.GetElementSpaceSize(),
741 a_buffers_offset,
742 max_lds_align);
743 const auto b_buffers_offset = a_block_space_size * NumGemmKPrefetchStage;
744 auto b_block_buffers =
745 ck::lds_utils::AllocateLdsBuffers<ComputeType, NumGemmKPrefetchStage>(
746 p_shared_block,
747 b_b_k0_n_k1_block_desc.GetElementSpaceSize(),
748 b_buffers_offset,
749 max_lds_align);
750
751 // gridwise GEMM pipeline
752 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
753 (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
754 (K0PerBlock * K1));
755
756 const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
757
758 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
759 a_b_k0_m_k1_block_desc,
760 a_blockwise_copy,
761 a_grid_buf,
762 a_block_buffers,
763 a_block_slice_copy_step,
764 b_b_k0_n_k1_grid_desc,
765 b_b_k0_n_k1_block_desc,
766 b_blockwise_copy,
767 b_grid_buf,
768 b_block_buffers,
769 b_block_slice_copy_step,
770 blockwise_gemm,
771 c_thread_buf,
772 num_k_block_main_loop);
773
774 // output: register to global memory
775 {
776 constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
777 constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
778
779 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
780 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
781
782 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
783 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
784
785 constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
786 constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
787 constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
788 constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
789 constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
790 constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
791 constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
792 constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
793
794 constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
796
798 static_cast<FloatC*>(p_shared_block),
799 c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
800
// View of the C staging tile with both block dimensions frozen at index 0 and the
// per-block M / N dimensions unmerged, giving eight dimensions:
// M pieces land at positions 0, 2, 4, 5, 6 and N pieces at positions 1, 3, 7.
801 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
802 c_block_desc_mblock_mperblock_nblock_nperblock,
804 make_freeze_transform(I0), // freeze mblock
805 make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
806 M1,
807 M2,
808 M3,
809 M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
810 make_freeze_transform(I0), // freeze nblock
811 make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
812 N1,
813 N2))), // N1 * N2 = NWave * NPerXdl (presumably N1 = NWave, N2 = NPerXdl)
814 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
816 Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
817
818 // calculate origin of thread output tensor on global memory
819 // blockwise GEMM c matrix starting index
820 const auto c_thread_mtx_on_block =
821 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
822
823 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
824 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
825
826 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
828 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
829 make_tuple(Sequence<0, 1, 2, 3, 4>{}),
830 make_tuple(Sequence<0>{}));
831
832 const auto m_thread_data_on_block_idx =
833 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
834 make_multi_index(m_thread_data_on_block));
835
836 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
839 make_tuple(Sequence<0, 1, 2>{}),
840 make_tuple(Sequence<0>{}));
841
842 const auto n_thread_data_on_block_idx =
843 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
844 make_multi_index(n_thread_data_on_block));
845
846 // VGPR to LDS
847 auto c_thread_copy_vgpr_to_lds =
848 ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
849 FloatC,
850 decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
851 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
852 ck::tensor_operation::element_wise::PassThrough,
853 Sequence<CShuffleMRepeatPerShuffle,
854 CShuffleNRepeatPerShuffle,
855 I1,
856 I1,
857 M2,
858 I1,
859 M4,
860 I1>,
861 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
862 7,
863 1,
865 1,
866 true>{
867 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
869 0,
870 m_thread_data_on_block_idx[I1],
871 n_thread_data_on_block_idx[I1],
872 m_thread_data_on_block_idx[I2],
873 m_thread_data_on_block_idx[I3],
874 m_thread_data_on_block_idx[I4],
875 n_thread_data_on_block_idx[I2]),
876 ck::tensor_operation::element_wise::PassThrough{}};
877
878 // LDS to global
879 auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
880 ThisThreadBlock, // index_t BlockSize,
881 CElementwiseOperation, // ElementwiseOperation,
882 CGlobalMemoryDataOperation, // DstInMemOp,
883 Sequence<1,
884 CShuffleMRepeatPerShuffle * MWave * MPerXdl,
885 1,
886 CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
887 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
888 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
889 FloatC, // typename SrcData,
890 FloatC, // typename DstData,
891 decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
892 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
893 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
894 3, // index_t VectorDim,
895 CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
896 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
897 false> // bool ThreadTransferDstResetCoordinateAfterRun
898 {c_block_desc_mblock_mperblock_nblock_nperblock,
899 make_multi_index(0, 0, 0, 0),
900 c_grid_desc_mblock_mperblock_nblock_nperblock,
901 make_multi_index(block_m_id, 0, block_n_id, 0),
902 c_element_op};
903
904 constexpr auto mxdlperwave_forward_step =
905 make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
906 constexpr auto nxdlperwave_forward_step =
907 make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
908 constexpr auto nxdlperwave_backward_step =
909 make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);
910
911 static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
912 constexpr auto mxdlperwave = mxdlperwave_iter;
913
914 static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
915 constexpr bool nxdlperwave_forward_sweep =
916 (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
917
918 constexpr index_t nxdlperwave_value =
919 nxdlperwave_forward_sweep
920 ? nxdlperwave_iter
921 : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
922
923 constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
924
925 // make sure it's safe to do ds_write
927
928 // VGPR to LDS
929 c_thread_copy_vgpr_to_lds.Run(
930 c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
931 make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
932 c_thread_buf,
933 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
934 c_block_buf);
935
936 // make sure it's safe to do ds_read
938
939 // LDS to global
940 c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
941 c_block_buf,
942 c_grid_desc_mblock_mperblock_nblock_nperblock,
943 c_grid_buf);
944
945 // move on nxdlperwave dimension
946 if constexpr(nxdlperwave_forward_sweep &&
947 (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
948 {
949 c_block_copy_lds_to_global.MoveDstSliceWindow(
950 c_grid_desc_mblock_mperblock_nblock_nperblock,
951 nxdlperwave_forward_step);
952 }
953 else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
954 {
955 c_block_copy_lds_to_global.MoveDstSliceWindow(
956 c_grid_desc_mblock_mperblock_nblock_nperblock,
957 nxdlperwave_backward_step);
958 }
959 });
960
961 // move on mxdlperwave dimension
962 if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
963 {
964 c_block_copy_lds_to_global.MoveDstSliceWindow(
965 c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
966 }
967 });
968 }
969 }
970};
971
972} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__global__ void kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, const Block2CTileMap &b2c_map, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:35
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v4
Definition gridwise_gemm_pipeline_selector.hpp:22
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
unsigned char uint8_t
Definition stdint.h:124
Simple tile mapping which creates 3D grid of block of threads.
Definition block_to_ctile_map.hpp:977
Argument(const FloatA *p_a_grid_, const FloatB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t MPadded_, index_t NPadded_, index_t KPadded_, index_t K0Padded_, index_t k_batch_)
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:141
index_t K
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:131
index_t k_batch
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:139
index_t StrideA
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:132
const FloatA * p_a_grid
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:126
void Print() const
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:172
index_t StrideC
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:134
index_t M
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:129
FloatC * p_c_grid
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:128
index_t MPadded
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:135
index_t StrideB
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:133
index_t NPadded
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:136
const FloatB * p_b_grid
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:127
index_t N
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:130
index_t K0Padded
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:138
index_t KPadded
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:137
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:99
__host__ static __device__ constexpr auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdlops_splitk_lds_direct_load.hpp:528
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition matrix_padder.hpp:134