device_batched_gemm_multi_d_xdl.hpp Source File

device_batched_gemm_multi_d_xdl.hpp Source File#

Composable Kernel: device_batched_gemm_multi_d_xdl.hpp Source File
device_batched_gemm_multi_d_xdl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25/*
26 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
27 *
28 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, Ds
29 * and E tensors for a given batch index. For example, ComputePtrOffsetOfStridedBatch computes the
30 * offsets of evenly strided batches, but we can easily extend to other layouts. The returned offset
31 * can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to
32 * the 2GB limitation.
33 *
34 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
35 * returns the 2D index of the tile that it computes. \see
36 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
37 *
38 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
39 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
40 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
41 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
42 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
43 * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
44 *
45 * \note \p Block2ETileMap allows customized mapping between a workgroup and the E-tile it computes.
46 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
47 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
48 *
49 */
50template <typename GridwiseGemm,
51 typename ABDataType,
52 typename DsPointer,
53 typename EDataType,
54 typename AElementwiseOperation,
55 typename BElementwiseOperation,
56 typename CDEElementwiseOperation,
57 typename AGridDesc_AK0_M_AK1,
58 typename BGridDesc_BK0_N_BK1,
59 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
60 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
61 typename ComputePtrOffsetOfBatch,
62 typename Block2ETileMap,
63 bool HasMainKBlockLoop>
64__global__ void
65#if CK_USE_LAUNCH_BOUNDS
67#endif
68 kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
69 const ABDataType* __restrict__ p_b_grid,
70 DsPointer p_ds_grid,
71 EDataType* __restrict__ p_e_grid,
72 const index_t batch_count,
73 const AElementwiseOperation a_element_op,
74 const BElementwiseOperation b_element_op,
75 const CDEElementwiseOperation cde_element_op,
76 const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
77 const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
78 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
79 ds_grid_desc_mblock_mperblock_nblock_nperblock,
80 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
81 e_grid_desc_mblock_mperblock_nblock_nperblock_,
82 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
83 const Block2ETileMap block_2_etile_map)
84{
85#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
86 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
87 {
88 const index_t num_blocks_per_batch =
89 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
90 const index_t g_idx =
91 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
92
93 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
94 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
95 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
96 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
97 const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
98 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
99
100 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
101
102 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
103
104 DsPointer p_ds_grid_grp;
105
106 static constexpr index_t NumDTensor =
107 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
108
110 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
111
112 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
113 p_a_grid + a_batch_offset,
114 p_b_grid + b_batch_offset,
115 p_ds_grid_grp,
116 p_e_grid + e_batch_offset,
117 p_shared,
118 a_element_op,
119 b_element_op,
120 cde_element_op,
121 a_grid_desc_k0_m_k1,
122 b_grid_desc_k0_n_k1,
123 ds_grid_desc_mblock_mperblock_nblock_nperblock,
124 e_grid_desc_mblock_mperblock_nblock_nperblock_,
125 block_2_etile_map);
126 }
127#else
128 ignore = p_a_grid;
129 ignore = p_b_grid;
130 ignore = p_ds_grid;
131 ignore = p_e_grid;
132 ignore = batch_count;
133 ignore = a_grid_desc_k0_m_k1;
134 ignore = b_grid_desc_k0_n_k1;
135 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
136 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
137 ignore = a_element_op;
138 ignore = b_element_op;
139 ignore = cde_element_op;
140 ignore = compute_ptr_offset_of_batch;
141 ignore = block_2_etile_map;
142#endif
143}
144
145template <typename ALayout,
146 typename BLayout,
147 typename DsLayout,
148 typename ELayout,
149 typename ADataType,
150 typename BDataType,
151 typename AccDataType,
152 typename CShuffleDataType,
153 typename DsDataType,
154 typename EDataType,
155 typename AElementwiseOperation,
156 typename BElementwiseOperation,
157 typename CDEElementwiseOperation,
158 GemmSpecialization GemmSpec,
159 index_t NumGemmKPrefetchStage,
160 index_t BlockSize,
161 index_t MPerBlock,
162 index_t NPerBlock,
163 index_t KPerBlock,
164 index_t AK1,
165 index_t BK1,
166 index_t MPerXDL,
167 index_t NPerXDL,
168 index_t MXdlPerWave,
169 index_t NXdlPerWave,
170 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
171 typename ABlockTransferThreadClusterArrangeOrder,
172 typename ABlockTransferSrcAccessOrder,
173 index_t ABlockTransferSrcVectorDim,
174 index_t ABlockTransferSrcScalarPerVector,
175 index_t ABlockTransferDstScalarPerVector_AK1,
176 bool ABlockLdsExtraM,
177 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
178 typename BBlockTransferThreadClusterArrangeOrder,
179 typename BBlockTransferSrcAccessOrder,
180 index_t BBlockTransferSrcVectorDim,
181 index_t BBlockTransferSrcScalarPerVector,
182 index_t BBlockTransferDstScalarPerVector_BK1,
183 bool BBlockLdsExtraN,
184 index_t CShuffleMXdlPerWavePerShuffle,
185 index_t CShuffleNXdlPerWavePerShuffle,
186 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
187 index_t CDEBlockTransferScalarPerVector_NPerBlock,
190 BLayout,
191 DsLayout,
192 ELayout,
193 ADataType,
194 BDataType,
195 DsDataType,
196 EDataType,
197 AElementwiseOperation,
198 BElementwiseOperation,
199 CDEElementwiseOperation>
200{
202
204 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
205 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
206
207 static constexpr index_t NumDTensor = DsDataType::Size();
208
209 static constexpr auto I0 = Number<0>{};
210 static constexpr auto I1 = Number<1>{};
211 static constexpr auto I2 = Number<2>{};
212 static constexpr auto I3 = Number<3>{};
213
214 static constexpr auto matrix_padder =
215 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
216
217 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
218 {
219 const auto a_grid_desc_mraw_kraw = [&]() {
221 {
222 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
223 make_tuple(StrideA, I1));
224 }
226 {
227 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
228 make_tuple(I1, StrideA));
229 }
230 }();
231
232 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
233 }
234
235 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
236 {
237 const auto b_grid_desc_nraw_kraw = [&]() {
239 {
240 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
241 make_tuple(I1, StrideB));
242 }
244 {
245 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
246 make_tuple(StrideB, I1));
247 }
248 }();
249
250 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
251 }
252
253 template <typename ELay>
254 static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
255 {
256 const auto e_grid_desc_mraw_nraw = [&]() {
258 {
259 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
260 make_tuple(StrideE, I1));
261 }
263 {
264 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
265 make_tuple(I1, StrideE));
266 }
267 }();
268
269 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
270 }
271
272 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
273 const std::array<index_t, NumDTensor>& NRaws,
274 const std::array<index_t, NumDTensor>& DsStride)
275 {
276 return generate_tuple(
277 [&](auto i) {
278 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
279
280 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
281 },
283 }
284
285 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
286 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
289
291 {
293 index_t BatchStrideB,
294 std::array<ck::index_t, NumDTensor> BatchStrideDs,
295 index_t BatchStrideE)
296 : BatchStrideA_(BatchStrideA),
297 BatchStrideB_(BatchStrideB),
298 BatchStrideDs_(BatchStrideDs),
299 BatchStrideE_(BatchStrideE)
300 {
301 }
302
// Offset of batch g_idx within the A tensor: g_idx * BatchStrideA_.
// The stride is cast to long_index_t before the multiplication, so the product is
// evaluated in the wider type rather than in index_t. The kernel adds the result
// to the typed A pointer, so the offset is counted in elements of A.
303 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
304 {
305 return g_idx * static_cast<long_index_t>(BatchStrideA_);
306 }
307
// Offset of batch g_idx within the B tensor: g_idx * BatchStrideB_.
// The stride is cast to long_index_t before the multiplication, so the product is
// evaluated in the wider type rather than in index_t. The kernel adds the result
// to the typed B pointer, so the offset is counted in elements of B.
308 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
309 {
310 return g_idx * static_cast<long_index_t>(BatchStrideB_);
311 }
312
// Per-D-tensor batch offsets. Returns a std::array with NumDTensor entries where
// entry i is g_idx * BatchStrideDs_[i], the stride being widened to long_index_t
// before the multiplication.
313 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
314 {
315 std::array<long_index_t, NumDTensor> ds_offset; // declared without an initializer
// presumably static_for visits every i in [0, NumDTensor), so each entry is
// assigned before the array is returned - TODO confirm against static_for
316 static_for<0, NumDTensor, 1>{}([&](auto i) {
317 ds_offset[i] = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]);
318 });
319 return ds_offset;
320 }
321
// Offset of batch g_idx within the E (output) tensor: g_idx * BatchStrideE_.
// The stride is cast to long_index_t before the multiplication, so the product is
// evaluated in the wider type rather than in index_t. The kernel adds the result
// to the typed E pointer, so the offset is counted in elements of E.
322 __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
323 {
324 return g_idx * static_cast<long_index_t>(BatchStrideE_);
325 }
326
327 private:
328 index_t BatchStrideA_;
329 index_t BatchStrideB_;
330 std::array<ck::index_t, NumDTensor> BatchStrideDs_;
331 index_t BatchStrideE_;
332 };
333
334 using ComputeDataType = ADataType;
335
336 template <index_t NXdlPerWave_>
338 ADataType, // TODO: distinguish A/B datatype
339 BDataType,
341 AccDataType,
342 CShuffleDataType,
343 DsDataType,
344 EDataType,
345 AElementwiseOperation,
346 BElementwiseOperation,
347 CDEElementwiseOperation,
348 NumGemmKPrefetchStage,
349 BlockSize,
350 MPerBlock,
351 NPerBlock,
352 KPerBlock,
353 AK1,
354 BK1,
355 MPerXDL,
356 NPerXDL,
357 MXdlPerWave,
358 NXdlPerWave_,
359 ABlockTransferThreadClusterLengths_AK0_M_AK1,
360 ABlockTransferThreadClusterArrangeOrder,
361 ABlockTransferSrcAccessOrder,
362 ABlockTransferSrcVectorDim,
363 ABlockTransferSrcScalarPerVector,
364 ABlockTransferDstScalarPerVector_AK1,
365 false,
366 ABlockLdsExtraM,
367 BBlockTransferThreadClusterLengths_BK0_N_BK1,
368 BBlockTransferThreadClusterArrangeOrder,
369 BBlockTransferSrcAccessOrder,
370 BBlockTransferSrcVectorDim,
371 BBlockTransferSrcScalarPerVector,
372 BBlockTransferDstScalarPerVector_BK1,
373 false,
374 BBlockLdsExtraN,
375 CShuffleMXdlPerWavePerShuffle,
376 CShuffleNXdlPerWavePerShuffle,
377 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
378 CDEBlockTransferScalarPerVector_NPerBlock,
379 LoopSched>;
382
383 // desc for blockwise copy
386 AGridDesc_M_K{}))>;
389 BGridDesc_N_K{}))>;
392 DsGridDesc_M_N{}))>;
395 EGridDesc_M_N{}))>;
396
397 // block-to-e-tile map
400
401 // Argument
402 struct Argument : public BaseArgument
403 {
404 template <typename GridwiseGemm>
406 {
407 if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
412 {
414 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
416
418 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
420 }
421 }
422 Argument(const void* p_a_grid,
423 const void* p_b_grid,
424 std::array<const void*, NumDTensor> p_ds_grid,
425 void* p_e_grid,
426 index_t MRaw,
427 index_t NRaw,
428 index_t KRaw,
429 index_t Batch,
430 index_t StrideA,
431 index_t StrideB,
432 const std::array<ck::index_t, NumDTensor>& StrideDs,
433 index_t StrideE,
434 index_t BatchStrideA,
435 index_t BatchStrideB,
436 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
437 index_t BatchStrideE,
438 AElementwiseOperation a_element_op,
439 BElementwiseOperation b_element_op,
440 CDEElementwiseOperation cde_element_op)
441 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
442 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
443 p_ds_grid_{},
444 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
445 Batch_(Batch),
449 e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
451 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
453 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
456 compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
457 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
458 a_element_op_{a_element_op},
459 b_element_op_{b_element_op},
460 cde_element_op_{cde_element_op}
461 {
462 // populate pointer, desc for Ds
463 static_for<0, NumDTensor, 1>{}([&](auto i) {
464 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
465 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
466
467 // D pointer
468 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
469
470 // D desc
472 DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
473 });
474
475 // populate desc for Ds/E
476 if(get_warp_size() == 64)
477 {
478 if constexpr(NXdlPerWave64 > 0)
479 {
481 }
482 }
483 else
484 {
485 if constexpr(NXdlPerWave32 > 0)
486 {
488 }
489 }
490 }
491
492 void Print() const
493 {
494 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
495 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
497 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
498 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
499 }
500
501 // private:
502 // pointers
503 const ADataType* p_a_grid_;
504 const BDataType* p_b_grid_;
506 EDataType* p_e_grid_;
507
508 // Batch
510
511 // tensor descriptors for problem definition
516
517 // tensor descriptors for block/thread-wise copy
523
524 // for calculating batch offset
526
527 // block-to-e-tile map
529
530 // element-wise op
531 AElementwiseOperation a_element_op_;
532 BElementwiseOperation b_element_op_;
533 CDEElementwiseOperation cde_element_op_;
534 };
535
536 // Invoker
537 struct Invoker : public BaseInvoker
538 {
540
541 template <typename GridwiseGemm>
542 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
543 {
544 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
549 {
550 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
551 }
552
553 const index_t grid_size =
554 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.Batch_;
555
556 const auto K =
557 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
558
559 auto launch_kernel = [&](auto has_main_k_block_loop) {
560 constexpr bool has_main_loop = has_main_k_block_loop.value;
561
562 const auto kernel =
563 kernel_batched_gemm_xdl<GridwiseGemm,
564 ADataType, // TODO: distinguish A/B datatype
565 typename GridwiseGemm::DsGridPointer,
566 EDataType,
567 AElementwiseOperation,
568 BElementwiseOperation,
569 CDEElementwiseOperation,
574 ComputePtrOffsetOfStridedBatch,
576 has_main_loop>;
577
578 return launch_and_time_kernel(stream_config,
579 kernel,
580 dim3(grid_size),
581 dim3(BlockSize),
582 0,
583 arg.p_a_grid_,
584 arg.p_b_grid_,
585 arg.p_ds_grid_,
586 arg.p_e_grid_,
587 arg.Batch_,
588 arg.a_element_op_,
589 arg.b_element_op_,
590 arg.cde_element_op_,
597 };
598
599 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
600 {
601 return launch_kernel(integral_constant<bool, true>{});
602 }
603 else
604 {
605 return launch_kernel(integral_constant<bool, false>{});
606 }
607 }
608
610
611 // polymorphic
// Type-erased entry point of the invoker: downcasts p_arg to this operation's
// Argument and forwards it, together with stream_config, to the Run overload
// that accepts the concrete Argument. The float returned by that overload is
// passed through unchanged.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
// p_arg must therefore point to an Argument of this operation; a nullptr or a
// BaseArgument of another type yields a null pointer dereference here.
612 float Run(const BaseArgument* p_arg,
613 const StreamConfig& stream_config = StreamConfig{}) override
614 {
615 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
616 }
617 };
618
619 static bool IsSupportedArgument(const Argument& arg)
620 {
622 {
623 return false;
624 }
625 if(get_warp_size() == 64)
626 {
627 if constexpr(NXdlPerWave64 > 0)
628 {
634 }
635 }
636 else
637 {
638 if constexpr(NXdlPerWave32 > 0)
639 {
645 }
646 }
647 return false;
648 }
649
650 // polymorphic
// Type-erased variant of the support check: downcasts p_arg to this operation's
// Argument and returns the result of the static IsSupportedArgument(const Argument&).
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
// p_arg must therefore point to an Argument of this operation; a nullptr or a
// BaseArgument of another type yields a null pointer dereference here instead of
// a `false` result.
651 bool IsSupportedArgument(const BaseArgument* p_arg) override
652 {
653 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
654 }
655
// Convenience factory: builds an Argument by value, forwarding every parameter
// unchanged and in the same order to the Argument constructor.
//   p_a, p_b, p_e  untyped pointers to the A, B and E tensors; the Argument
//                  constructor casts them to ADataType / BDataType / EDataType.
//   p_ds           one untyped pointer per D tensor (NumDTensor entries).
//   M, N, K        received by the Argument constructor as MRaw, NRaw, KRaw.
//   Batch          batch count; stored as Batch_ and handed to the kernel.
//   StrideA/B/E    strides used when building the A, B and E descriptors.
//   StrideDs       one stride per D tensor.
//   BatchStride*   per-tensor batch strides consumed by ComputePtrOffsetOfStridedBatch.
//   *_element_op   element-wise operations, stored in the Argument by value.
656 static auto MakeArgument(const void* p_a,
657 const void* p_b,
658 const std::array<const void*, NumDTensor>& p_ds,
659 void* p_e,
660 index_t M,
661 index_t N,
662 index_t K,
663 index_t Batch,
664 index_t StrideA,
665 index_t StrideB,
666 const std::array<index_t, NumDTensor>& StrideDs,
667 index_t StrideE,
668 index_t BatchStrideA,
669 index_t BatchStrideB,
670 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
671 index_t BatchStrideE,
672 AElementwiseOperation a_element_op,
673 BElementwiseOperation b_element_op,
674 CDEElementwiseOperation cde_element_op)
675 {
676 return Argument{p_a,
677 p_b,
678 p_ds,
679 p_e,
680 M,
681 N,
682 K,
683 Batch,
684 StrideA,
685 StrideB,
686 StrideDs,
687 StrideE,
688 BatchStrideA,
689 BatchStrideB,
690 BatchStrideDs,
691 BatchStrideE,
692 a_element_op,
693 b_element_op,
694 cde_element_op};
695 }
696
// Convenience factory: returns a default-constructed Invoker by value.
697 static auto MakeInvoker() { return Invoker{}; }
698
699 // polymorphic
// Type-erased counterpart of MakeArgument: allocates an Argument with
// std::make_unique, forwarding every parameter unchanged and in the same order to
// the Argument constructor, and returns it as std::unique_ptr<BaseArgument>.
//   p_a, p_b, p_e  untyped pointers to the A, B and E tensors.
//   p_ds           one untyped pointer per D tensor (NumDTensor entries).
//   M, N, K        received by the Argument constructor as MRaw, NRaw, KRaw.
//   Batch          batch count.
//   Stride*        per-tensor strides (StrideDs holds one entry per D tensor).
//   BatchStride*   per-tensor batch strides (BatchStrideDs holds one entry per D tensor).
//   *_element_op   element-wise operations, stored in the Argument by value.
700 std::unique_ptr<BaseArgument>
701 MakeArgumentPointer(const void* p_a,
702 const void* p_b,
703 const std::array<const void*, NumDTensor>& p_ds,
704 void* p_e,
705 index_t M,
706 index_t N,
707 index_t K,
708 index_t Batch,
709 index_t StrideA,
710 index_t StrideB,
711 const std::array<ck::index_t, NumDTensor>& StrideDs,
712 index_t StrideE,
713 index_t BatchStrideA,
714 index_t BatchStrideB,
715 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
716 index_t BatchStrideE,
717 AElementwiseOperation a_element_op,
718 BElementwiseOperation b_element_op,
719 CDEElementwiseOperation cde_element_op) override
720 {
721 return std::make_unique<Argument>(p_a,
722 p_b,
723 p_ds,
724 p_e,
725 M,
726 N,
727 K,
728 Batch,
729 StrideA,
730 StrideB,
731 StrideDs,
732 StrideE,
733 BatchStrideA,
734 BatchStrideB,
735 BatchStrideDs,
736 BatchStrideE,
737 a_element_op,
738 b_element_op,
739 cde_element_op);
740 }
741
742 // polymorphic
// Type-erased counterpart of MakeInvoker: returns a heap-allocated Invoker as
// std::unique_ptr<BaseInvoker>. The heap object is constructed from a temporary
// default-constructed Invoker{}.
743 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
744 {
745 return std::make_unique<Invoker>(Invoker{});
746 }
747
748 // polymorphic
// Builds a human-readable identifier for this instantiation of the form
//   DeviceBatchedGemmMultiD_Xdl<BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, GemmSpec>
// where GemmSpec is rendered through getGemmSpecializationString().
// Only these seven template parameters are included, so instantiations that
// differ solely in other parameters (layouts, data types, transfer settings, ...)
// produce the same string.
749 std::string GetTypeString() const override
750 {
751 auto str = std::stringstream();
752
753 // clang-format off
754 str << "DeviceBatchedGemmMultiD_Xdl"
755 << "<"
756 << BlockSize << ", "
757 << MPerBlock << ", "
758 << NPerBlock << ", "
759 << KPerBlock << ", "
760 << AK1 << ", "
761 << BK1 << ", "
762 << getGemmSpecializationString(GemmSpec)
763 << ">";
764 // clang-format on
765
766 return str.str();
767 }
768};
769
770} // namespace device
771} // namespace tensor_operation
772} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_batched_gemm_xdl(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition device_batched_gemm_multi_d_xdl.hpp:68
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_multi_d_xdl.hpp:403
index_t Batch_
Definition device_batched_gemm_multi_d_xdl.hpp:509
EDataType * p_e_grid_
Definition device_batched_gemm_multi_d_xdl.hpp:506
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_batched_gemm_multi_d_xdl.hpp:521
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_multi_d_xdl.hpp:518
void Print() const
Definition device_batched_gemm_multi_d_xdl.hpp:492
const BDataType * p_b_grid_
Definition device_batched_gemm_multi_d_xdl.hpp:504
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_gemm_multi_d_xdl.hpp:525
Block2ETileMap block_2_etile_map_
Definition device_batched_gemm_multi_d_xdl.hpp:528
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_batched_gemm_multi_d_xdl.hpp:514
CDEElementwiseOperation cde_element_op_
Definition device_batched_gemm_multi_d_xdl.hpp:533
void init_ds_e_grid_desc()
Definition device_batched_gemm_multi_d_xdl.hpp:405
const ADataType * p_a_grid_
Definition device_batched_gemm_multi_d_xdl.hpp:503
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_batched_gemm_multi_d_xdl.hpp:522
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_multi_d_xdl.hpp:422
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_batched_gemm_multi_d_xdl.hpp:505
AElementwiseOperation a_element_op_
Definition device_batched_gemm_multi_d_xdl.hpp:531
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_gemm_multi_d_xdl.hpp:515
AGridDesc_M_K a_grid_desc_m_k_
Definition device_batched_gemm_multi_d_xdl.hpp:512
BElementwiseOperation b_element_op_
Definition device_batched_gemm_multi_d_xdl.hpp:532
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_multi_d_xdl.hpp:519
BGridDesc_N_K b_grid_desc_n_k_
Definition device_batched_gemm_multi_d_xdl.hpp:513
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multi_d_xdl.hpp:313
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multi_d_xdl.hpp:308
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, std::array< ck::index_t, NumDTensor > BatchStrideDs, index_t BatchStrideE)
Definition device_batched_gemm_multi_d_xdl.hpp:292
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multi_d_xdl.hpp:303
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multi_d_xdl.hpp:322
Definition device_batched_gemm_multi_d_xdl.hpp:538
DeviceBatchedGemmMultiD_Xdl::Argument Argument
Definition device_batched_gemm_multi_d_xdl.hpp:539
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_multi_d_xdl.hpp:612
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multi_d_xdl.hpp:542
Definition device_batched_gemm_multi_d_xdl.hpp:200
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_batched_gemm_multi_d_xdl.hpp:384
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_multi_d_xdl.hpp:205
static constexpr index_t NumDTensor
Definition device_batched_gemm_multi_d_xdl.hpp:207
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_multi_d_xdl.hpp:217
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_multi_d_xdl.hpp:337
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_multi_d_xdl.hpp:380
static auto MakeInvoker()
Definition device_batched_gemm_multi_d_xdl.hpp:697
static constexpr auto I3
Definition device_batched_gemm_multi_d_xdl.hpp:212
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_batched_gemm_multi_d_xdl.hpp:398
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_batched_gemm_multi_d_xdl.hpp:387
static constexpr auto I2
Definition device_batched_gemm_multi_d_xdl.hpp:211
ADataType ComputeDataType
Definition device_batched_gemm_multi_d_xdl.hpp:334
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_batched_gemm_multi_d_xdl.hpp:254
static constexpr auto I1
Definition device_batched_gemm_multi_d_xdl.hpp:210
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_gemm_multi_d_xdl.hpp:701
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition device_batched_gemm_multi_d_xdl.hpp:287
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_multi_d_xdl.hpp:204
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_multi_d_xdl.hpp:381
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_multi_d_xdl.hpp:656
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_batched_gemm_multi_d_xdl.hpp:272
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_batched_gemm_multi_d_xdl.hpp:286
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_batched_gemm_multi_d_xdl.hpp:285
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_multi_d_xdl.hpp:743
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_batched_gemm_multi_d_xdl.hpp:288
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_multi_d_xdl.hpp:235
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_multi_d_xdl.hpp:651
std::string GetTypeString() const override
Definition device_batched_gemm_multi_d_xdl.hpp:749
DeviceBatchedGemmMultiD_Xdl DeviceOp
Definition device_batched_gemm_multi_d_xdl.hpp:201
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_batched_gemm_multi_d_xdl.hpp:393
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_multi_d_xdl.hpp:619
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_batched_gemm_multi_d_xdl.hpp:390
static constexpr auto I0
Definition device_batched_gemm_multi_d_xdl.hpp:209
static constexpr auto matrix_padder
Definition device_batched_gemm_multi_d_xdl.hpp:214
Definition device_batched_gemm_multi_d.hpp:27
Definition matrix_padder.hpp:180