gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp Source File

gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp Source File

Composable Kernel: gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp Source File
gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
20
21namespace ck {
22
23// GEMM:
24// input : A[M, K]
25// input : B[N, K]
26// input : D0[M, N], D1[M, N], ...
27// output : E[M, N]
28// C = a_op(A) * b_op(B)
29// E = cde_op(C, D0, D1, ...)
30// Assume:
31// D0, D1, ... and E have the same layout
32template <typename ADataType,
33 typename BDataType,
34 typename AComputeType,
35 typename BComputeType,
36 typename AccDataType,
37 typename CShuffleDataType,
38 typename DsDataType,
39 typename EDataType,
40 typename AElementwiseOperation,
41 typename BElementwiseOperation,
42 typename CDEElementwiseOperation,
43 index_t NumGemmKPrefetchStage,
44 index_t BlockSize,
45 index_t MPerBlock,
46 index_t NPerBlock,
47 index_t KPerBlock,
48 index_t AK1Value,
49 index_t BK1Value,
50 index_t MPerXdl,
51 index_t NPerXdl,
52 index_t MXdlPerWave,
53 index_t NXdlPerWave,
54 typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
55 typename ABlockTransferThreadClusterArrangeOrder,
56 typename ABlockTransferSrcAccessOrder,
57 index_t ABlockTransferSrcVectorDim,
58 index_t ABlockTransferSrcScalarPerVector,
59 index_t ABlockTransferDstScalarPerVector_AK1,
60 bool AThreadTransferSrcResetCoordinateAfterRun,
61 index_t ABlockLdsExtraM,
62 typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
63 typename BBlockTransferThreadClusterArrangeOrder,
64 typename BBlockTransferSrcAccessOrder,
65 index_t BBlockTransferSrcVectorDim,
66 index_t BBlockTransferSrcScalarPerVector,
67 index_t BBlockTransferDstScalarPerVector_BK1,
68 bool BThreadTransferSrcResetCoordinateAfterRun,
69 index_t BBlockLdsExtraN,
70 index_t CShuffleMXdlPerWavePerShuffle,
71 index_t CShuffleNXdlPerWavePerShuffle,
72 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
73 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
74 LoopScheduler LoopSched,
75 PipelineVersion PipelineVer,
76 typename ALDSType,
77 typename BLDSType>
79{
80 static constexpr index_t NumDTensor = DsDataType::Size();
81
83
84 static constexpr auto I0 = Number<0>{};
85 static constexpr auto I1 = Number<1>{};
86 static constexpr auto I2 = Number<2>{};
87 static constexpr auto I3 = Number<3>{};
88 static constexpr auto I4 = Number<4>{};
89 static constexpr auto I5 = Number<5>{};
90 static constexpr auto I6 = Number<6>{};
91 static constexpr auto I7 = Number<7>{};
92
93 // K1 should be Number<...>
94 static constexpr auto AK1 = Number<AK1Value>{};
95 static constexpr auto BK1 = Number<BK1Value>{};
96 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
97 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
98
100
103
104 __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
105 {
106 // A matrix in LDS memory, dst of blockwise copy
111 AK1,
112 I1));
113 }
114
115 __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
116 {
117 // B matrix in LDS memory, dst of blockwise copy
122 BK1,
123 I1));
124 }
125
126 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
127 {
128 // A matrix in LDS memory, dst of blockwise copy
132 }
133
134 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
135 {
136 // B matrix in LDS memory, dst of blockwise copy
140 }
141
142 __host__ __device__ static constexpr auto
144 {
145 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
146 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
147
148 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
152 I1,
154
155 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
156 }
157
158 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
159 static constexpr auto MakeDsGridPointer()
160 {
161 return generate_tuple(
162 [&](auto i) {
163 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
164
165 return static_cast<const DDataType*>(nullptr);
166 },
168 }
169
170 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
171 {
172 // LDS allocation for A and B: be careful of alignment
173 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
174 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
175
176 // lds max alignment
177 constexpr auto max_lds_align = math::lcm(AK1, BK1);
178
179 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
180 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
181
182 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
183 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
184
185 // LDS allocation for C shuffle in LDS
186 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
188
189 constexpr auto c_block_size =
190 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
191
192 return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
193 b_block_space_size_aligned * sizeof(BLDSType),
194 c_block_size * sizeof(CShuffleDataType));
195 }
196
197 __host__ __device__ static auto CalculateMPadded(index_t M)
198 {
199 return math::integer_least_multiple(M, MPerBlock);
200 }
201
202 __host__ __device__ static auto CalculateNPadded(index_t N)
203 {
204 return math::integer_least_multiple(N, NPerBlock);
205 }
206
207 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch)
208 {
209 return math::integer_least_multiple(K, KPerBlock * K_Batch);
210 }
211
212 template <typename ALayout, GemmSpecialization GemmSpec>
213 __host__ __device__ static auto
215 {
216 const auto a_grid_desc_m_k = [&]() {
218 {
219 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
220 }
222 {
223 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
224 }
225 }();
226
227 const auto MPad = CalculateMPadded(M);
228 const auto KPad = CalculateKPadded(K, KBatch);
229
230 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
231 a_grid_desc_m_k,
235
236 const auto AK0 = KPad / (KBatch * AK1);
237
242 {
243 // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
245 a_grid_desc_m_kpad,
247 make_right_pad_transform(M, MPad - M)),
250 }
251 else
252 {
254 a_grid_desc_m_kpad,
259 }
260 }
261
262 template <typename BLayout, GemmSpecialization GemmSpec>
263 __host__ __device__ static auto
265 {
266 const auto b_grid_desc_k_n = [&]() {
268 {
269 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
270 }
272 {
273 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
274 }
275 }();
276
277 const auto NPad = CalculateNPadded(N);
278 const auto KPad = CalculateKPadded(K, KBatch);
279
280 const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
281 b_grid_desc_k_n,
285
286 const auto BK0 = KPad / (KBatch * BK1);
287
292 {
293 // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
295 b_grid_desc_kpad_n,
297 make_right_pad_transform(N, NPad - N)),
300 }
301 else
302 {
304 b_grid_desc_kpad_n,
309 }
310 }
311
312 // E desc for destination in blockwise copy
313 template <typename EGridDesc_M_N>
314 __host__ __device__ static constexpr auto
315 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
316 {
317 const auto M = e_grid_desc_m_n.GetLength(I0);
318 const auto N = e_grid_desc_m_n.GetLength(I1);
319
320 const auto MBlock = M / MPerBlock;
321 const auto NBlock = N / NPerBlock;
322
323 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
324 e_grid_desc_m_n,
329
330 return e_grid_desc_mblock_mperblock_nblock_nperblock;
331 }
332
333 // Ds desc for source in blockwise copy
334 template <typename DsGridDesc_M_N>
335 __host__ __device__ static constexpr auto
336 MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
337 {
338 return generate_tuple(
339 [&](auto i) {
341 },
343 }
344
345 // return block_id to E matrix tile idx (m0, n0) mapping
346 template <typename EGridDesc_M_N>
347 __host__ __device__ static constexpr auto
348 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
349 {
351 e_grid_desc_m_n);
352 }
353
355
356 template <typename ALayout,
357 typename BLayout,
358 typename DsLayout,
359 typename ELayout,
360 GemmSpecialization GemmSpec>
361 __host__ __device__ static constexpr bool
363 const index_t N,
364 const index_t K,
365 const index_t StrideA,
366 const index_t StrideB,
367 const std::array<index_t, NumDTensor> StrideDs,
368 const index_t StrideE,
369 const index_t KBatch)
370 {
371 const auto a_grid_desc_kbatch_ak0_m_ak1 =
373 const auto b_grid_desc_kbatch_bk0_n_bk1 =
375
376 ignore = StrideDs;
377
378 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
379
380#if 0
381 // check tile size
382 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
383 {
384 return false;
385 }
386#endif
387
388 // check gridwise gemm pipeline
389 const auto num_k_loop = K / KPerBlock;
390
391 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
392 {
393 return false;
394 }
395
396 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
397 // check tensor size: cannot be larger than 2GB each
398 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
399
400 if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
401 b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
402 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
403 {
404 return false;
405 }
406
407 return true;
408 }
409
410 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
411 {
412 const index_t num_loop = K / KPerBlock;
413
414 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
415 }
416
417 using DsGridPointer = decltype(MakeDsGridPointer());
418
419 template <typename ELayout, GemmSpecialization GemmSpec>
420 __host__ __device__ static auto
422 {
423 constexpr auto matrix_padder =
425 MPerBlock, NPerBlock, KPerBlock};
426 const auto e_grid_desc_mraw_nraw = [&]() {
428 {
429 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
430 make_tuple(StrideE, I1));
431 }
433 {
434 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
435 make_tuple(I1, StrideE));
436 }
437 }();
438
439 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
440 }
441
442 template <typename DsLayout, GemmSpecialization GemmSpec>
443 __host__ __device__ static auto
444 MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
445 const std::array<index_t, NumDTensor>& NRaws,
446 const std::array<index_t, NumDTensor>& DsStride)
447 {
448 return generate_tuple(
449 [&](auto i) {
450 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
451
452 return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
453 },
455 }
456
457 __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
458
459 template <bool HasMainKBlockLoop,
460 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
461 index_t NumDTensor_,
462 typename DsDataType_,
463 bool Zeroing,
464 typename AGridDesc_KBatch_AK0_M_AK1,
465 typename BGridDesc_KBatch_BK0_N_BK1,
466 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
467 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
468 typename CDEElementwiseOperation_,
469 typename Block2ETileMap>
470 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
471 const BDataType* __restrict__ p_b_grid,
472 DsGridPointer p_ds_grid,
473 EDataType* __restrict__ p_e_grid,
474 void* __restrict__ p_shared,
475 uint32_t* barrier_count_finished,
476 const index_t KBatch,
477 const AElementwiseOperation& a_element_op,
478 const BElementwiseOperation& b_element_op,
479 const CDEElementwiseOperation_& cde_element_op,
480 const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
481 const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
482 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
483 ds_grid_desc_mblock_mperblock_nblock_nperblock,
484 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
485 e_grid_desc_mblock_mperblock_nblock_nperblock,
486 const Block2ETileMap& block_2_etile_map)
487 {
488 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
489 p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
490
491 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
492 p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());
493
494 const auto ds_grid_buf = generate_tuple(
495 [&](auto i) {
497 p_ds_grid[i],
498 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
499 },
501
503 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
504
505 // divide block work by [M, N]
506 const auto block_work_idx =
507 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
508
509 // HACK: this force m/n_block_data_idx_on_grid into SGPR
510 const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
511
512 const index_t m_block_data_idx_on_grid =
513 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
514
515 const index_t n_block_data_idx_on_grid =
516 __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
517
518 // lds max alignment
519 constexpr auto max_lds_align = math::lcm(AK1, BK1);
520
521 // A matrix in LDS memory, dst of blockwise copy
522 constexpr auto a_block_desc_kbatch_ak0_m_ak1 =
524
525 // B matrix in LDS memory, dst of blockwise copy
526 constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
528
529 // A matrix blockwise copy
530 auto a_blockwise_copy =
532 AElementwiseOperation,
536 ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
537 ABlockTransferThreadClusterArrangeOrder,
538 ADataType,
539 ALDSType,
540 decltype(a_grid_desc_kbatch_ak0_m_ak1),
541 decltype(a_block_desc_kbatch_ak0_m_ak1),
542 ABlockTransferSrcAccessOrder,
544 ABlockTransferSrcVectorDim,
545 3,
546 ABlockTransferSrcScalarPerVector,
547 ABlockTransferDstScalarPerVector_AK1,
548 1,
549 1,
550 AThreadTransferSrcResetCoordinateAfterRun,
551 true,
552 NumGemmKPrefetchStage>(
553 a_grid_desc_kbatch_ak0_m_ak1,
554 make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0),
555 a_element_op,
556 a_block_desc_kbatch_ak0_m_ak1,
557 make_multi_index(0, 0, 0, 0),
559
560 // B matrix blockwise copy
561 auto b_blockwise_copy =
563 BElementwiseOperation,
567 BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
568 BBlockTransferThreadClusterArrangeOrder,
569 BDataType,
570 BLDSType,
571 decltype(b_grid_desc_kbatch_bk0_n_bk1),
572 decltype(b_block_desc_kbatch_bk0_n_bk1),
573 BBlockTransferSrcAccessOrder,
575 BBlockTransferSrcVectorDim,
576 3,
577 BBlockTransferSrcScalarPerVector,
578 BBlockTransferDstScalarPerVector_BK1,
579 1,
580 1,
581 BThreadTransferSrcResetCoordinateAfterRun,
582 true,
583 NumGemmKPrefetchStage>(
584 b_grid_desc_kbatch_bk0_n_bk1,
585 make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0),
586 b_element_op,
587 b_block_desc_kbatch_bk0_n_bk1,
588 make_multi_index(0, 0, 0, 0),
590
591 // A matrix in LDS memory, dst of blockwise copy
592 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
593
594 // B matrix in LDS memory, dst of blockwise copy
595 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
596
597 // GEMM definition
598 // c_mtx += transpose(a_mtx) * b_mtx
599 // a_mtx[K0PerBlock, MPerBlock] is in LDS
600 // b_mtx[K0PerBlock, NPerBlock] is in LDS
601 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
602 // register
603 // sanity check
604 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
605 constexpr bool is_single_rate_mfma =
607 lcm_AK1_BK1 <= 4) ||
608 (is_same<AComputeType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
610 lcm_AK1_BK1 < 32))
611 ? true
612 : false;
613 constexpr auto is_scale_mfma = false;
614 constexpr index_t KPack = math::max(lcm_AK1_BK1,
615 MfmaSelector<AComputeType,
616 MPerXdl,
617 NPerXdl,
618 AComputeType,
619 is_single_rate_mfma,
620 is_scale_mfma>::selected_mfma.k_per_blk);
621
623 BlockSize,
624 ALDSType,
625 BLDSType,
626 AccDataType,
627 decltype(a_block_desc_ak0_m_ak1),
628 decltype(b_block_desc_bk0_n_bk1),
629 MPerXdl,
630 NPerXdl,
631 MXdlPerWave,
632 NXdlPerWave,
633 KPack,
634 LoopSched,
635 AComputeType,
636 BComputeType>();
637
638 if constexpr(Zeroing)
639 {
640 if(block_work_idx[I0] == 0)
641 {
642 const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
643 const index_t numNThreads = NPerBlock / nThreadSize;
644 const index_t numMThreads = BlockSize / numNThreads;
645 const index_t mThreadSize = MPerBlock / numMThreads;
646
647 const index_t m_tid = get_thread_local_1d_id() / numNThreads;
648 const index_t n_tid = get_thread_local_1d_id() % numNThreads;
649
650 auto c_thread_desc_mblock_mperblock_nblock_nperblock =
653
655 EDataType,
656 c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
657 true>
658 e_thread_zero_buf;
659
660 auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
661 EDataType,
662 EDataType,
663 decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
664 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
668 3,
669 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
671 1,
672 true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
673 make_multi_index(block_work_idx[I1],
674 m_tid * mThreadSize,
675 block_work_idx[I2],
676 n_tid * nThreadSize),
678
679 c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
680 make_tuple(I0, I0, I0, I0),
681 e_thread_zero_buf,
682 e_grid_desc_mblock_mperblock_nblock_nperblock,
683 e_grid_buf);
684
685 __builtin_amdgcn_s_barrier();
686
687 if(threadIdx.x == 0)
688 {
689 atomicAdd(barrier_count_finished, 1);
690 }
691 }
692 }
693
694 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
695
696 // LDS allocation for A and B: be careful of alignment
697 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
698 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
699
701 static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
702
704 static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
705 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
706
707 constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
708 constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
709
710 // gridwise GEMM pipeline
711 const auto gridwise_gemm_pipeline =
713
714 const index_t num_k_block_main_loop =
715 __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
716 a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
717 KPerBlock);
718
719 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
720 a_block_desc_kbatch_ak0_m_ak1,
721 a_blockwise_copy,
722 a_grid_buf,
723 a_block_buf,
724 a_block_slice_copy_step,
725 b_grid_desc_kbatch_bk0_n_bk1,
726 b_block_desc_kbatch_bk0_n_bk1,
727 b_blockwise_copy,
728 b_grid_buf,
729 b_block_buf,
730 b_block_slice_copy_step,
731 blockwise_gemm,
732 c_thread_buf,
733 num_k_block_main_loop);
734
735 // shuffle C and write out
736 {
737 if constexpr(Zeroing)
738 {
739 if(threadIdx.x == 0)
740 {
741 while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
742 }
743 __builtin_amdgcn_s_barrier();
744 }
745
746 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
747 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
748 "wrong!");
749
750 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
751 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
752
753 // TODO: hacky, fix it!
754 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
755 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
756
757 // TODO: hacky, fix it!
758 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
759 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
760 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
761
762 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
763 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
764 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
765 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
766 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
767 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
768 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
769 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
770
771 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
773
774 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
775 static_cast<CShuffleDataType*>(p_shared),
776 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
777
778 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
779 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
783 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
784 M1, // M1 = MWave
785 M2, // M2 * M3 * M4 = MPerXdl
786 M3,
787 M4)),
790 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
791 N1, // N1 = NWave
792 N2))), // N2 = NPerXdl
796
797 // calculate origin of thread output tensor on global memory
798 // blockwise GEMM c matrix starting index
799 const auto c_thread_mtx_on_block =
800 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
801
802 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
803 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
804
805 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
807 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
810
811 const auto m_thread_data_on_block_idx =
812 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
813 make_multi_index(m_thread_data_on_block));
814
815 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
820
821 const auto n_thread_data_on_block_idx =
822 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
823 make_multi_index(n_thread_data_on_block));
824
825 // shuffle: threadwise copy C from VGPR to LDS
826 auto c_thread_copy_vgpr_to_lds =
828 CShuffleDataType,
829 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
830 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
832 Sequence<CShuffleMXdlPerWavePerShuffle,
833 CShuffleNXdlPerWavePerShuffle,
834 I1,
835 I1,
836 M2,
837 I1,
838 M4,
839 I1>,
841 7,
842 1,
844 1,
845 true>{
846 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
848 0,
849 m_thread_data_on_block_idx[I1],
850 n_thread_data_on_block_idx[I1],
851 m_thread_data_on_block_idx[I2],
852 m_thread_data_on_block_idx[I3],
853 m_thread_data_on_block_idx[I4],
854 n_thread_data_on_block_idx[I2]),
856
857 // tuple of reference to C/Ds tensor descriptors
858 const auto c_ds_desc_refs = concat_tuple_of_reference(
859 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
860 generate_tie([&](auto i) -> const auto& // return type should be reference
861 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
863
864 // tuple of reference to C/Ds tensor descriptors
865 const auto c_ds_buf_refs = concat_tuple_of_reference(
866 tie(c_shuffle_block_buf),
867 generate_tie([&](auto i) -> const auto& // return type should be reference
868 { return ds_grid_buf[i]; },
870
871 // tuple of starting index of C/Ds blockwise copy
872 const auto idx_c_ds_block_begin = container_concat(
873 make_tuple(make_multi_index(0, 0, 0, 0)),
875 [&](auto) {
876 return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
877 },
879
880 // space filling curve for threadwise C in VGPR before shuffle
881 constexpr auto sfc_c_vgpr =
884 Sequence<CShuffleMXdlPerWavePerShuffle,
885 CShuffleNXdlPerWavePerShuffle,
886 1,
887 1,
888 M2,
889 1,
890 M4,
891 1>>{};
892
893 // space filling curve for shuffled blockwise C/D/E
894 constexpr auto sfc_cde_block =
897 Sequence<1,
898 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
899 1,
900 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
901
902 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
903
904 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
905
906 // blockwise copy C/D/E between LDS and global
907 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
909 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
911 decltype(c_ds_desc_refs),
912 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
913 CDEElementwiseOperation_,
914 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
915 // Sequence support
916 // arbitray type
917 Sequence<1,
918 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
919 1,
920 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
921 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
922 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
923 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
924 3, // index_t VectorDim,
925 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
928 uniform_sequence_gen_t<NumDTensor_,
929 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
930 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
931 {c_ds_desc_refs,
932 idx_c_ds_block_begin,
933 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
934 make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
935 cde_element_op};
936
937 static_for<0, num_access, 1>{}([&](auto access_id) {
938 // make sure it's safe to write to LDS
940
941 // each thread write its data from VGPR to LDS
942 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
943 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
944 c_thread_buf,
945 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
946 c_shuffle_block_buf);
947
948 // make sure it's safe to read from LDS
950
951 // each block copy its data from LDS to global
952 cde_block_copy_lds_and_global.Run(
953 c_ds_desc_refs,
954 c_ds_buf_refs,
955 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
956 tie(e_grid_buf));
957
958 if constexpr(access_id < num_access - 1)
959 {
960 constexpr auto cde_lds_and_global_step =
961 sfc_cde_block.GetForwardStep(access_id);
962
963 // move on Ds
964 static_for<0, NumDTensor_, 1>{}([&](auto i) {
965 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
966 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
967 });
968
969 // move on E
970 cde_block_copy_lds_and_global.MoveDstSliceWindow(
971 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
972 I0,
973 cde_lds_and_global_step);
974 }
975 });
976
977 if constexpr(Zeroing)
978 {
979 if(threadIdx.x == 0)
980 {
981 index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
982
983 if(k_id_finished_t == KBatch)
984 {
985 *barrier_count_finished = 0;
986 }
987 }
988 }
989 }
990 }
991
992 template <bool HasMainKBlockLoop,
993 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
994 GemmSpecialization GemmSpec,
995 typename ALayout,
996 typename BLayout,
997 typename DsLayout,
998 typename ELayout,
999 typename Block2ETileMap>
1000 __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
1001 const void* __restrict__ p_b_grid_,
1002 DsGridPointer p_ds_grid,
1003 void* __restrict__ p_e_grid_,
1004 void* __restrict__ p_shared,
1005 uint32_t* barrier_count_finished,
1006 const AElementwiseOperation& a_element_op,
1007 const BElementwiseOperation& b_element_op,
1008 const CDEElementwiseOperation& cde_element_op,
1009 const index_t M,
1010 const index_t N,
1011 const index_t K,
1012 const index_t StrideA,
1013 const index_t StrideB,
1014 const std::array<index_t, NumDTensor> StrideDs,
1015 const index_t StrideE,
1016 const index_t KBatch,
1017 const Block2ETileMap& block_2_etile_map)
1018 {
1019 const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
1020 const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
1021 const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
1022
1023 using DsGridDesc_M_N =
1025
1026 DsGridDesc_M_N ds_grid_desc_m_n;
1027
1028 static_for<0, NumDTensor, 1>{}([&](auto j) {
1029 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
1030
1031 ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
1032 });
1033
1034 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
1035
1036 // tensor descriptors for block/thread-wise copy
1037 const auto a_grid_desc_kbatch_ak0_m_ak1 =
1039
1040 const auto b_grid_desc_kbatch_bk0_n_bk1 =
1042
1043 using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
1045 DsGridDesc_M_N{}))>;
1046
1047 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
1048
1049 static_for<0, NumDTensor, 1>{}([&](auto j) {
1050 ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
1052 });
1053
1054 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1056
1057 const auto block_work_idx =
1058 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1059
1060 const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1061
1062 if(kbatch_id == KBatch - 1)
1063 {
1065 p_a_grid,
1066 p_b_grid,
1067 p_ds_grid,
1068 p_e_grid,
1069 p_shared,
1070 barrier_count_finished,
1071 KBatch,
1072 a_element_op,
1073 b_element_op,
1074 cde_element_op,
1075 a_grid_desc_kbatch_ak0_m_ak1,
1076 b_grid_desc_kbatch_bk0_n_bk1,
1077 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1078 e_grid_desc_mblock_mperblock_nblock_nperblock,
1079 block_2_etile_map);
1080 }
1081 else
1082 {
1084 p_a_grid,
1085 p_b_grid,
1086 p_ds_grid,
1087 p_e_grid,
1088 p_shared,
1089 barrier_count_finished,
1090 KBatch,
1091 a_element_op,
1092 b_element_op,
1094 a_grid_desc_kbatch_ak0_m_ak1,
1095 b_grid_desc_kbatch_bk0_n_bk1,
1096 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1097 e_grid_desc_mblock_mperblock_nblock_nperblock,
1098 block_2_etile_map);
1099 }
1100 }
1101
1102 template <bool HasMainKBlockLoop,
1103 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
1104 GemmSpecialization GemmSpec,
1105 typename ALayout,
1106 typename BLayout,
1107 typename DsLayout,
1108 typename ELayout,
1109 typename Block2ETileMap>
1110 __device__ static void Run(const void* __restrict__ p_a_grid_,
1111 const void* __restrict__ p_b_grid_,
1112 DsGridPointer p_ds_grid,
1113 void* __restrict__ p_e_grid_,
1114 void* __restrict__ p_shared,
1115 uint32_t*,
1116 const AElementwiseOperation& a_element_op,
1117 const BElementwiseOperation& b_element_op,
1118 const CDEElementwiseOperation& cde_element_op,
1119 const index_t M,
1120 const index_t N,
1121 const index_t K,
1122 const index_t StrideA,
1123 const index_t StrideB,
1124 const std::array<index_t, NumDTensor> StrideDs,
1125 const index_t StrideE,
1126 const index_t KBatch,
1127 const Block2ETileMap& block_2_etile_map)
1128 {
1129 const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
1130 const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
1131 const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
1132
1133 using DsGridDesc_M_N =
1135
1136 DsGridDesc_M_N ds_grid_desc_m_n;
1137
1138 static_for<0, NumDTensor, 1>{}([&](auto j) {
1139 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
1140
1141 ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
1142 });
1143
1144 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
1145
1146 // tensor descriptors for block/thread-wise copy
1147 const auto a_grid_desc_kbatch_ak0_m_ak1 =
1149
1150 const auto b_grid_desc_kbatch_bk0_n_bk1 =
1152
1153 using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
1155 DsGridDesc_M_N{}))>;
1156
1157 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
1158
1159 static_for<0, NumDTensor, 1>{}([&](auto j) {
1160 ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
1162 });
1163
1164 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1166
1168 p_a_grid,
1169 p_b_grid,
1170 p_ds_grid,
1171 p_e_grid,
1172 p_shared,
1173 nullptr,
1174 KBatch,
1175 a_element_op,
1176 b_element_op,
1177 cde_element_op,
1178 a_grid_desc_kbatch_ak0_m_ak1,
1179 b_grid_desc_kbatch_bk0_n_bk1,
1180 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1181 e_grid_desc_mblock_mperblock_nblock_nperblock,
1182 block_2_etile_map);
1183 }
1184};
1185
1186} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
unsigned int uint32_t
Definition stdint.h:126
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:79
static __device__ void RunWithZeroing(const void *__restrict__ p_a_grid_, const void *__restrict__ p_b_grid_, DsGridPointer p_ds_grid, void *__restrict__ p_e_grid_, void *__restrict__ p_shared, uint32_t *barrier_count_finished, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< index_t, NumDTensor > StrideDs, const index_t StrideE, const index_t KBatch, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:1000
__host__ static __device__ auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:421
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, uint32_t *barrier_count_finished, const index_t KBatch, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation_ &cde_element_op, const AGridDesc_KBatch_AK0_M_AK1 &a_grid_desc_kbatch_ak0_m_ak1, const BGridDesc_KBatch_BK0_N_BK1 &b_grid_desc_kbatch_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:470
static __device__ void Run(const void *__restrict__ p_a_grid_, const void *__restrict__ p_b_grid_, DsGridPointer p_ds_grid, void *__restrict__ p_e_grid_, void *__restrict__ p_shared, uint32_t *, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const index_t M, const index_t N, const index_t K, const index_t StrideA, const index_t StrideB, const std::array< index_t, NumDTensor > StrideDs, const index_t StrideE, const index_t KBatch, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp:1110
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:16
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7.hpp:42
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340