device_gemm_multiple_d_xdl_cshuffle.hpp Source File

device_gemm_multiple_d_xdl_cshuffle.hpp Source File

Composable Kernel: device_gemm_multiple_d_xdl_cshuffle.hpp Source File
device_gemm_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#ifndef __HIPCC_RTC__
7#include <iostream>
8#include <sstream>
11#endif
12
21
22namespace ck {
23
24template <typename GridwiseGemm,
25 typename ADataType,
26 typename BDataType,
27 typename DsPointer,
28 typename EDataType,
29 typename AElementwiseOperation,
30 typename BElementwiseOperation,
31 typename CDEElementwiseOperation,
32 typename AGridDesc_AK0_M_AK1,
33 typename BGridDesc_BK0_N_BK1,
34 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
35 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename Block2ETileMap,
37 bool HasMainKBlockLoop>
38__global__ void
39#if CK_USE_LAUNCH_BOUNDS
41#endif
42 kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
43 const BDataType* __restrict__ p_b_grid,
44 DsPointer p_ds_grid,
45 EDataType* __restrict__ p_e_grid,
46 const AElementwiseOperation a_element_op,
47 const BElementwiseOperation b_element_op,
48 const CDEElementwiseOperation cde_element_op,
49 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
50 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
51 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
52 ds_grid_desc_mblock_mperblock_nblock_nperblock,
53 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
54 e_grid_desc_mblock_mperblock_nblock_nperblock,
55 const Block2ETileMap block_2_etile_map)
56{
57#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
58 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
59 {
60 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
61
62 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
63 p_a_grid,
64 p_b_grid,
65 p_ds_grid,
66 p_e_grid,
67 p_shared,
68 a_element_op,
69 b_element_op,
70 cde_element_op,
71 a_grid_desc_ak0_m_ak1,
72 b_grid_desc_bk0_n_bk1,
73 ds_grid_desc_mblock_mperblock_nblock_nperblock,
74 e_grid_desc_mblock_mperblock_nblock_nperblock,
75 block_2_etile_map);
76 }
77#else
78 ignore = p_a_grid;
79 ignore = p_b_grid;
80 ignore = p_ds_grid;
81 ignore = p_e_grid;
82 ignore = a_element_op;
83 ignore = b_element_op;
84 ignore = cde_element_op;
85 ignore = a_grid_desc_ak0_m_ak1;
86 ignore = b_grid_desc_bk0_n_bk1;
87 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
88 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
89 ignore = block_2_etile_map;
90#endif
91}
92
93} // namespace ck
94
95namespace ck {
96namespace tensor_operation {
97namespace device {
98
99// GEMM:
100// input : A[M, K]
101// input : B[N, K]
102// input : D0[M, N], D1[M, N], ...
103// output : E[M, N]
104// C = a_op(A) * b_op(B)
105// E = cde_op(C, D0, D1, ...)
106// Assume:
107// D0, D1, ... and E have the same layout
108template <typename ALayout,
109 typename BLayout,
110 typename DsLayout,
111 typename ELayout,
112 typename ADataType,
113 typename BDataType,
114 typename AccDataType,
115 typename CShuffleDataType,
116 typename DsDataType,
117 typename EDataType,
118 typename AElementwiseOperation,
119 typename BElementwiseOperation,
120 typename CDEElementwiseOperation,
121 GemmSpecialization GemmSpec,
122 index_t NumGemmKPrefetchStage,
123 index_t BlockSize,
124 index_t MPerBlock,
125 index_t NPerBlock,
126 index_t KPerBlock,
127 index_t AK1,
128 index_t BK1,
129 index_t MPerXDL,
130 index_t NPerXDL,
131 index_t MXdlPerWave,
132 index_t NXdlPerWave,
133 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
134 typename ABlockTransferThreadClusterArrangeOrder,
135 typename ABlockTransferSrcAccessOrder,
136 index_t ABlockTransferSrcVectorDim,
137 index_t ABlockTransferSrcScalarPerVector,
138 index_t ABlockTransferDstScalarPerVector_AK1,
139 index_t ABlockLdsExtraM,
140 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
141 typename BBlockTransferThreadClusterArrangeOrder,
142 typename BBlockTransferSrcAccessOrder,
143 index_t BBlockTransferSrcVectorDim,
144 index_t BBlockTransferSrcScalarPerVector,
145 index_t BBlockTransferDstScalarPerVector_BK1,
146 index_t BBlockLdsExtraN,
147 index_t CShuffleMXdlPerWavePerShuffle,
148 index_t CShuffleNXdlPerWavePerShuffle,
149 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
150 index_t CDEBlockTransferScalarPerVector_NPerBlock,
153 typename ComputeDataType = EDataType>
155 BLayout,
156 DsLayout,
157 ELayout,
158 ADataType,
159 BDataType,
160 DsDataType,
161 EDataType,
162 AElementwiseOperation,
163 BElementwiseOperation,
164 CDEElementwiseOperation>
165{
168 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
169 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
170
171 static constexpr index_t NumDTensor = DsDataType::Size();
172
173 static constexpr auto I0 = Number<0>{};
174 static constexpr auto I1 = Number<1>{};
175 static constexpr auto I2 = Number<2>{};
176 static constexpr auto I3 = Number<3>{};
177
178 static constexpr auto matrix_padder =
179 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
180
181 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
182 {
183 const auto a_grid_desc_mraw_kraw = [&]() {
185 {
186 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
187 make_tuple(StrideA, I1));
188 }
190 {
191 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
192 make_tuple(I1, StrideA));
193 }
194 }();
195
196 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
197 }
198
199 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
200 {
201 const auto b_grid_desc_nraw_kraw = [&]() {
203 {
204 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
205 make_tuple(I1, StrideB));
206 }
208 {
209 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
210 make_tuple(StrideB, I1));
211 }
212 }();
213
214 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
215 }
216
217 template <typename ELay>
218 static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
219 {
220 const auto e_grid_desc_mraw_nraw = [&]() {
222 {
223 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
224 make_tuple(StrideE, I1));
225 }
227 {
228 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
229 make_tuple(I1, StrideE));
230 }
231 }();
232
233 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
234 }
235
237 const Array<index_t, NumDTensor>& NRaws,
238 const Array<index_t, NumDTensor>& DsStride)
239 {
240 return generate_tuple(
241 [&](auto i) {
242 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
243
244 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
245 },
247 }
248
249 // desc for problem definition
250 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
251 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
254
255 // GridwiseGemm
256 template <index_t NXdlPerWave_>
258 ADataType,
259 BDataType,
260 ComputeDataType,
261 AccDataType,
262 CShuffleDataType,
263 DsDataType,
264 EDataType,
265 AElementwiseOperation,
266 BElementwiseOperation,
267 CDEElementwiseOperation,
268 NumGemmKPrefetchStage,
269 BlockSize,
270 MPerBlock,
271 NPerBlock,
272 KPerBlock,
273 AK1,
274 BK1,
275 MPerXDL,
276 NPerXDL,
277 MXdlPerWave,
278 NXdlPerWave_,
279 ABlockTransferThreadClusterLengths_AK0_M_AK1,
280 ABlockTransferThreadClusterArrangeOrder,
281 ABlockTransferSrcAccessOrder,
282 ABlockTransferSrcVectorDim,
283 ABlockTransferSrcScalarPerVector,
284 ABlockTransferDstScalarPerVector_AK1,
285 false,
286 ABlockLdsExtraM,
287 BBlockTransferThreadClusterLengths_BK0_N_BK1,
288 BBlockTransferThreadClusterArrangeOrder,
289 BBlockTransferSrcAccessOrder,
290 BBlockTransferSrcVectorDim,
291 BBlockTransferSrcScalarPerVector,
292 BBlockTransferDstScalarPerVector_BK1,
293 false,
294 BBlockLdsExtraN,
295 CShuffleMXdlPerWavePerShuffle,
296 CShuffleNXdlPerWavePerShuffle,
297 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
298 CDEBlockTransferScalarPerVector_NPerBlock,
299 LoopSched,
300 PipelineVer>;
303
304 // desc for blockwise copy
307 AGridDesc_M_K{}))>;
310 BGridDesc_N_K{}))>;
313 DsGridDesc_M_N{}))>;
316 EGridDesc_M_N{}))>;
317
318 // block-to-e-tile map
321
322#ifndef __HIPCC_RTC__
323 // Argument
324 struct Argument : public BaseArgument
325 {
326 Argument(const void* p_a_grid,
327 const void* p_b_grid,
328 std::array<const void*, NumDTensor> p_ds_grid,
329 void* p_e_grid,
330 index_t MRaw,
331 index_t NRaw,
332 index_t KRaw,
333 index_t StrideA,
334 index_t StrideB,
335 std::array<index_t, NumDTensor> StrideDs,
336 index_t StrideE,
337 AElementwiseOperation a_element_op,
338 BElementwiseOperation b_element_op,
339 CDEElementwiseOperation cde_element_op)
340 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
341 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
342 p_ds_grid_{},
343 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
347 e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
349 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
351 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
352 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
353 a_element_op_{a_element_op},
354 b_element_op_{b_element_op},
355 cde_element_op_{cde_element_op},
356 MRaw_{MRaw},
357 NRaw_{NRaw},
358 KRaw_{KRaw}
359 {
360 // populate pointer, desc for Ds
361 static_for<0, NumDTensor, 1>{}([&](auto i) {
362 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
363 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
364
365 // D pointer
366 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
367
368 // D desc
370 DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
371 });
372 }
373
374 void Print() const
375 {
376 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
377 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
379 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
380 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
381 }
382
383 // private:
384 // pointers
385 const ADataType* p_a_grid_;
386 const BDataType* p_b_grid_;
388 EDataType* p_e_grid_;
389
390 // tensor descriptors for problem definiton
395
396 // tensor descriptors for block/thread-wise copy
399
400 // block-to-e-tile map
402
403 // element-wise op
404 AElementwiseOperation a_element_op_;
405 BElementwiseOperation b_element_op_;
406 CDEElementwiseOperation cde_element_op_;
407
408 // for checking vector load/store
412 };
413
414 // Invoker
415 struct Invoker : public BaseInvoker
416 {
418
419 template <typename GridwiseGemm>
420 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
421 {
422 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
427 {
428 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
429 }
430 auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
431 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
433
434 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
435 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
436 arg.e_grid_desc_m_n_);
437
438 const index_t grid_size =
439 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
440
441 auto launch_kernel = [&](auto has_main_k_block_loop) {
442 constexpr bool has_main_loop = has_main_k_block_loop.value;
443
444 const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
445 GridwiseGemm,
446 ADataType, // TODO: distiguish A/B datatype
447 BDataType, // TODO: distiguish A/B datatype
448 typename GridwiseGemm::DsGridPointer,
449 EDataType,
450 AElementwiseOperation,
451 BElementwiseOperation,
452 CDEElementwiseOperation,
458 has_main_loop>;
459
460 return launch_and_time_kernel(stream_config,
461 kernel,
462 dim3(grid_size),
463 dim3(BlockSize),
464 0,
465 arg.p_a_grid_,
466 arg.p_b_grid_,
467 arg.p_ds_grid_,
468 arg.p_e_grid_,
469 arg.a_element_op_,
470 arg.b_element_op_,
471 arg.cde_element_op_,
474 ds_grid_desc_mblock_mperblock_nblock_nperblock,
475 e_grid_desc_mblock_mperblock_nblock_nperblock,
477 };
478
479 const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
480
481 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
482 {
483 return launch_kernel(integral_constant<bool, true>{});
484 }
485 else
486 {
487 return launch_kernel(integral_constant<bool, false>{});
488 }
489 }
490
492
493 // polymorphic
494 float Run(const BaseArgument* p_arg,
495 const StreamConfig& stream_config = StreamConfig{}) override
496 {
497 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
498 }
499 };
500
501#endif
502
503 static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
504 {
505 // check vector load/store
508 // check vector load of A
509 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
510 {
511 if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
512 {
513 return false;
514 }
515 }
516 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
517 {
518 // FIXME: not rigorous
519 if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
520 {
521 return false;
522 }
523 }
524 else
525 {
526 return false;
527 }
528 // check vector laod of B
529 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
530 {
531 if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
532 {
533 return false;
534 }
535 }
536 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
537 {
538 // FIXME: not rigorous
539 if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
540 {
541 return false;
542 }
543 }
544 else
545 {
546 return false;
547 }
548
549 // check vector load of Ds
550 // only support RowMajor for now
551 bool all_valid = true;
552
553 static_for<0, NumDTensor, 1>{}([&](auto i) {
554 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
555
556 if constexpr(!is_same_v<DLayout, Row>)
557 {
558 all_valid = false;
559 }
560 });
561
562 if(!all_valid)
563 {
564 return false;
565 }
566
567 // check vector store of E
568 // only support RowMajor for now
569 if constexpr(is_same_v<ELayout, Row>)
570 {
571 if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
572 {
573 return false;
574 }
575 }
576 else
577 {
578 return false;
579 }
580 return true;
581 }
582
583#ifndef __HIPCC_RTC__
584 static bool IsSupportedArgument(const Argument& arg)
585 {
587 {
588 return false;
589 }
590 if(!IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_))
591 {
592 return false;
593 }
594
595 if(get_warp_size() == 64)
596 {
597 if constexpr(NXdlPerWave64 > 0)
598 {
604 }
605 }
606 else
607 {
608 if constexpr(NXdlPerWave32 > 0)
609 {
615 }
616 }
617 return false;
618 }
619
620 // polymorphic
621 bool IsSupportedArgument(const BaseArgument* p_arg) override
622 {
623 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
624 }
625
626 static auto MakeArgument(const void* p_a,
627 const void* p_b,
628 std::array<const void*, NumDTensor> p_ds,
629 void* p_e,
630 index_t MRaw,
631 index_t NRaw,
632 index_t KRaw,
633 index_t StrideA,
634 index_t StrideB,
635 std::array<index_t, NumDTensor> StrideDs,
636 index_t StrideE,
637 AElementwiseOperation a_element_op,
638 BElementwiseOperation b_element_op,
639 CDEElementwiseOperation cde_element_op)
640 {
641 return Argument{p_a,
642 p_b,
643 p_ds,
644 p_e,
645 MRaw,
646 NRaw,
647 KRaw,
648 StrideA,
649 StrideB,
650 StrideDs,
651 StrideE,
652 a_element_op,
653 b_element_op,
654 cde_element_op};
655 }
656
657 static auto MakeInvoker() { return Invoker{}; }
658
659 // polymorphic
660 std::unique_ptr<BaseArgument>
661 MakeArgumentPointer(const void* p_a,
662 const void* p_b,
663 std::array<const void*, NumDTensor> p_ds,
664 void* p_e,
665 index_t MRaw,
666 index_t NRaw,
667 index_t KRaw,
668 index_t StrideA,
669 index_t StrideB,
670 std::array<ck::index_t, NumDTensor> StrideDs,
671 index_t StrideE,
672 AElementwiseOperation a_element_op,
673 BElementwiseOperation b_element_op,
674 CDEElementwiseOperation cde_element_op) override
675 {
676 return std::make_unique<Argument>(p_a,
677 p_b,
678 p_ds,
679 p_e,
680 MRaw,
681 NRaw,
682 KRaw,
683 StrideA,
684 StrideB,
685 StrideDs,
686 StrideE,
687 a_element_op,
688 b_element_op,
689 cde_element_op);
690 }
691
692 // polymorphic
693 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
694 {
695 return std::make_unique<Invoker>(Invoker{});
696 }
697
698 // polymorphic
699 std::string GetTypeString() const override
700 {
701 auto str = std::stringstream();
702
703 std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
705 "Interwave" }};
706
707 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
709 "v2" }};
710
711 // clang-format off
712 str << "DeviceGemmMultipleD_Xdl_CShuffle"
713 << "<"
714 << BlockSize << ", "
715 << MPerBlock << ", "
716 << NPerBlock << ", "
717 << KPerBlock << ", "
718 << AK1 << ", "
719 << BK1 << ", "
720 << MPerXDL << ", "
721 << NPerXDL << ", "
722 << MXdlPerWave << ", "
723 << NXdlPerWave << ", "
724 << ABlockTransferSrcScalarPerVector << ", "
725 << BBlockTransferSrcScalarPerVector << ", "
726 << CShuffleMXdlPerWavePerShuffle << ", "
727 << CShuffleNXdlPerWavePerShuffle << ", "
728 << getGemmSpecializationString(GemmSpec)
729 << ">"
730 << " LoopScheduler: "
731 << LoopSchedToString[LoopSched] << ", "
732 << "PipelineVersion: "
733 << PipelineVersionToString[PipelineVer];
734 // clang-format on
735
736 return str.str();
737 }
738#endif
739
740 template <class ADesc, class BDesc, class DsDesc, class EDesc>
742 {
743 static constexpr auto ds_tuple()
744 {
745 return transform_tuples(
746 [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
747 DsDesc{});
748 }
750 remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
752 remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
755 remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
758 DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
761 DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
764 ds_tuple()))>;
767 DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
769 DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
770
771 // tensor descriptors for problem definiton
776
777 // tensor descriptors for block/thread-wise copy
782
783 // block-to-e-tile map
785
786 // element-wise op
787 AElementwiseOperation a_element_op;
788 BElementwiseOperation b_element_op;
789 CDEElementwiseOperation cde_element_op;
790
791 // for checking vector load/store
795
797
798 constexpr Descriptor(ADesc a,
799 BDesc b,
800 DsDesc ds,
801 EDesc e,
802 AElementwiseOperation a_element_op_,
803 BElementwiseOperation b_element_op_,
804 CDEElementwiseOperation cde_element_op_)
805 : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
806 b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
808 [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
809 ds)},
810 e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
811 a_grid_desc_ak0_m_ak1{
812 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
813 b_grid_desc_bk0_n_bk1{
814 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
815 ds_grid_desc_mblock_mperblock_nblock_nperblock{
816 GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
818 [&](auto d) constexpr {
819 return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
820 },
821 ds))},
822 e_grid_desc_mblock_mperblock_nblock_nperblock{
823 GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
824 e_grid_desc_m_n)},
825 block_2_etile_map{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
826 has_main_k_block_loop{GridwiseGemm64::CalculateHasMainKBlockLoop(
827 a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
828 a_element_op{a_element_op_},
829 b_element_op{b_element_op_},
830 cde_element_op{cde_element_op_},
831 MRaw{e.GetLength(I0)},
832 NRaw{e.GetLength(I1)},
833 KRaw{a.GetLength(I1)}
834 {
835 }
836
837 constexpr bool IsValid() const
838 {
839 if(get_warp_size() == 64)
840 {
841 if constexpr(NXdlPerWave64 > 0)
842 {
849 GridwiseGemm64::template IsValidCompilationParameter<>();
850 }
851 }
852 else
853 {
854 if constexpr(NXdlPerWave32 > 0)
855 {
862 GridwiseGemm32::template IsValidCompilationParameter<>();
863 }
864 }
865 return false;
866 }
867
868 constexpr index_t GetBlockSize() const { return BlockSize; }
869
870 constexpr index_t GetGridSize() const
871 {
872 return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
873 }
874 };
875
876 template <class ADesc, class BDesc, class DsDesc, class EDesc>
877 static constexpr auto
879 BDesc b,
880 DsDesc ds,
881 EDesc e,
882 AElementwiseOperation a_element_op = AElementwiseOperation{},
883 BElementwiseOperation b_element_op = BElementwiseOperation{},
884 CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
885 {
886 return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
887 a, b, ds, e, a_element_op, b_element_op, cde_element_op);
888 }
889
890 template <class Desc, class DsPointer>
891 __device__ static void Run(const Desc& desc,
892 const ADataType* __restrict__ p_a_grid,
893 const BDataType* __restrict__ p_b_grid,
894 DsPointer p_ds_grid,
895 EDataType* __restrict__ p_e_grid)
896 {
897
898#ifndef __HIPCC_RTC__
899 assert(desc.IsValid());
900#endif
901 using GridwiseGemm = conditional_t<get_warp_size() == 64, GridwiseGemm64, GridwiseGemm32>;
902 __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
903 if(desc.has_main_k_block_loop)
904 {
905 GridwiseGemm::template Run<true, InMemoryDataOperationEnum::Set>(
906 p_a_grid,
907 p_b_grid,
908 p_ds_grid,
909 p_e_grid,
910 p_shared_block,
911 desc.a_element_op,
912 desc.b_element_op,
913 desc.cde_element_op,
914 desc.a_grid_desc_ak0_m_ak1,
915 desc.b_grid_desc_bk0_n_bk1,
916 desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
917 desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
918 desc.block_2_etile_map);
919 }
920 else
921 {
922 GridwiseGemm::template Run<false, InMemoryDataOperationEnum::Set>(
923 p_a_grid,
924 p_b_grid,
925 p_ds_grid,
926 p_e_grid,
927 p_shared_block,
928 desc.a_element_op,
929 desc.b_element_op,
930 desc.cde_element_op,
931 desc.a_grid_desc_ak0_m_ak1,
932 desc.b_grid_desc_bk0_n_bk1,
933 desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
934 desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
935 desc.block_2_etile_map);
936 }
937 }
938};
939
940} // namespace device
941} // namespace tensor_operation
942} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_multiple_d_xdl_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:42
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr auto transform_tuples(F f, const X &x)
Definition tuple_helper.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition ck/stream_config.hpp:10
Definition utility/array.hpp:14
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:325
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:398
index_t MRaw_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:409
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:386
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:397
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:394
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:405
index_t KRaw_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:411
index_t NRaw_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:410
EDataType * p_e_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:388
AGridDesc_M_K a_grid_desc_m_k_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:391
Block2ETileMap block_2_etile_map_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:401
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:404
void Print() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:374
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:393
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:385
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:406
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:326
BGridDesc_N_K b_grid_desc_n_k_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:392
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:387
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:779
index_t NRaw
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:793
index_t MRaw
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:792
AElementwiseOperation a_element_op
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:787
constexpr index_t GetGridSize() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:870
DsGridDesc_M_N ds_grid_desc_m_n
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:774
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:765
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:778
constexpr bool IsValid() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:837
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))> EGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:754
EGridDesc_M_N e_grid_desc_m_n
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:775
Block2ETileMap block_2_etile_map
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:784
constexpr Descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:798
AGridDesc_M_K a_grid_desc_m_k
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:772
remove_cvref_t< decltype(ds_tuple())> DsGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:753
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))> AGridDesc_M_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:749
constexpr index_t GetBlockSize() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:868
index_t KRaw
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:794
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))> BGridDesc_N_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:751
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))> AGridDesc_AK0_M_AK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:756
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_tuple()))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:762
bool has_main_k_block_loop
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:796
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap( DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> Block2ETileMap
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:768
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:780
BElementwiseOperation b_element_op
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:788
CDEElementwiseOperation cde_element_op
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:789
static constexpr auto ds_tuple()
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:743
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))> BGridDesc_BK0_N_BK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:759
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:781
BGridDesc_N_K b_grid_desc_n_k
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:773
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:416
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:417
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:420
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:494
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:165
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:308
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:257
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:693
static constexpr auto I1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:174
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:699
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:314
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:252
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:168
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:503
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:251
DeviceGemmMultipleD_Xdl_CShuffle DeviceOp
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:166
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:301
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:626
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:302
static constexpr auto make_descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation cde_element_op=CDEElementwiseOperation{})
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:878
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:305
static constexpr auto I0
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:173
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:250
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:178
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:253
static __device__ void Run(const Desc &desc, const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:891
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:171
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:169
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:657
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:311
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:661
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:621
static auto MakeDsGridDescriptor_M_N(const Array< index_t, NumDTensor > &MRaws, const Array< index_t, NumDTensor > &NRaws, const Array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:236
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:584
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:199
static constexpr auto I3
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:176
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:218
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:181
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:319
static constexpr auto I2
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:175
Definition device_gemm_multiple_d.hpp:36
Definition matrix_padder.hpp:180