device_grouped_conv_bwd_weight_xdl_cshuffle.hpp Source File

device_grouped_conv_bwd_weight_xdl_cshuffle.hpp Source File#

Composable Kernel: device_grouped_conv_bwd_weight_xdl_cshuffle.hpp Source File
device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
Go to the documentation of this file.
1// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
2// SPDX-License-Identifier: MIT
3
4#pragma once
5
6#include <iostream>
7#include <numeric>
8#include <sstream>
9
11#include "ck/utility/env.hpp"
26
27#ifdef CK_EXPERIMENTAL_BUILDER
28#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
29#endif
30
31namespace ck {
32namespace tensor_operation {
33namespace device {
34
// Device kernel that runs one GridwiseGemm tile per workgroup for a batch of GEMMs.
//
// The launch grid is split evenly into `batch_count` groups of consecutive blocks. Each block
// derives its batch index from its 1-D block id, shifts the A/B/C base pointers by the
// per-batch offsets reported by `compute_ptr_offset_of_batch`, and then calls
// GridwiseGemm::Run<HasMainKBlockLoop>.
//
// Template parameters:
//   GridwiseGemm            - provides IsValidCompilationParameter<>(),
//                             GetSharedMemoryNumberOfByte() and Run<bool>().
//   FloatA / FloatB / FloatC - element types of the A, B and C buffers.
//   *ElementwiseOperation   - functors forwarded unchanged to GridwiseGemm::Run.
//   *GridDesc*, Block2CTileMap - descriptor / tile-map types forwarded unchanged.
//   ComputePtrOffsetOfBatch - provides GetAPtrOffset/GetBPtrOffset/GetCPtrOffset(batch index).
//   HasMainKBlockLoop       - compile-time flag forwarded to GridwiseGemm::Run.
//
// Parameters:
//   p_a_grid, p_b_grid, p_c_grid - base pointers of the A, B and C buffers (batch 0).
//   batch_count                  - number of batches the grid is divided into.
//   remaining arguments          - passed through to GridwiseGemm::Run.
35template <typename GridwiseGemm,
36 typename FloatA,
37 typename FloatB,
38 typename FloatC,
39 typename AElementwiseOperation,
40 typename BElementwiseOperation,
41 typename CElementwiseOperation,
42 typename AGridDesc_B_K0_M_K1,
43 typename BGridDesc_B_K0_N_K1,
44 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
45 typename Block2CTileMap,
46 typename ComputePtrOffsetOfBatch,
47 bool HasMainKBlockLoop>
48__global__ void
49#if CK_USE_LAUNCH_BOUNDS
51#endif
52 kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
53 const FloatB* __restrict__ p_b_grid,
54 FloatC* __restrict__ p_c_grid,
55 const AElementwiseOperation a_element_op,
56 const BElementwiseOperation b_element_op,
57 const CElementwiseOperation c_element_op,
58 const index_t batch_count,
59 const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
60 const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
61 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
62 c_grid_desc_mblock_mperblock_nblock_nperblock,
63 const Block2CTileMap block_2_ctile_map,
64 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
65{
 // The GEMM body is only compiled for the architectures listed below; on any other target
 // the kernel falls through to the #else branch and performs no GEMM work.
66#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
67 defined(__gfx12__)
68 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
69 {
 // Integer division: the host side of this file launches a grid whose size is a
 // per-batch grid size multiplied by the batch count, so the split is exact there.
70 const index_t num_blocks_per_batch =
71 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
 // Batch (group) index handled by this workgroup.
72 const index_t g_idx =
73 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
74
 // Element offsets (not bytes): they are added directly to the typed pointers below.
75 const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
76 const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
77 const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
78
 // Block-shared scratch buffer; its byte size comes from GridwiseGemm and is expressed
 // here as a count of FloatA elements.
79 __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
80
81 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
82 p_b_grid + b_batch_offset,
83 p_c_grid + c_batch_offset,
84 p_shared,
85 a_b_k0_m_k1_grid_desc,
86 b_b_k0_n_k1_grid_desc,
87 c_grid_desc_mblock_mperblock_nblock_nperblock,
88 a_element_op,
89 b_element_op,
90 c_element_op,
91 block_2_ctile_map);
92 }
93#else
 // Unsupported target: every argument is assigned to `ignore` (presumably to silence
 // unused-parameter warnings - TODO confirm) and nothing is computed.
94 ignore = p_a_grid;
95 ignore = p_b_grid;
96 ignore = p_c_grid;
97 ignore = a_b_k0_m_k1_grid_desc;
98 ignore = b_b_k0_n_k1_grid_desc;
99 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
100 ignore = a_element_op;
101 ignore = b_element_op;
102 ignore = c_element_op;
103 ignore = batch_count;
104 ignore = block_2_ctile_map;
105 ignore = compute_ptr_offset_of_batch;
106
 // NOTE(review): the results of these calls are discarded; the reason for calling the
 // offset accessors here is not evident from this file - confirm before removing.
107 compute_ptr_offset_of_batch.GetAPtrOffset(0);
108 compute_ptr_offset_of_batch.GetBPtrOffset(0);
109 compute_ptr_offset_of_batch.GetCPtrOffset(0);
110#endif // end of supported-architecture guard (gfx908 / gfx90a / gfx94 / gfx11 / gfx12)
111}
112
113// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
114template <ck::index_t NDimSpatial,
115 typename InLayout,
116 typename WeiLayout,
117 typename OutLayout,
118 typename InDataType,
119 typename WeiDataType,
120 typename OutDataType,
121 typename AccDataType,
122 typename InElementwiseOperation,
123 typename WeiElementwiseOperation,
124 typename OutElementwiseOperation,
125 ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
126 ck::index_t BlockSize,
127 ck::index_t MPerBlock,
128 ck::index_t NPerBlock,
129 ck::index_t K0PerBlock,
130 ck::index_t K1,
131 ck::index_t MPerXDL,
132 ck::index_t NPerXDL,
133 ck::index_t MXdlPerWave,
134 ck::index_t NXdlPerWave,
135 typename ABlockTransferThreadClusterLengths_K0_M_K1,
136 typename ABlockTransferThreadClusterArrangeOrder,
137 typename ABlockTransferSrcAccessOrder,
138 ck::index_t ABlockTransferSrcVectorDim,
139 ck::index_t ABlockTransferSrcScalarPerVector,
140 ck::index_t ABlockTransferDstScalarPerVector_K1,
141 bool ABlockLdsAddExtraM,
142 typename BBlockTransferThreadClusterLengths_K0_N_K1,
143 typename BBlockTransferThreadClusterArrangeOrder,
144 typename BBlockTransferSrcAccessOrder,
145 ck::index_t BBlockTransferSrcVectorDim,
146 ck::index_t BBlockTransferSrcScalarPerVector,
147 ck::index_t BBlockTransferDstScalarPerVector_K1,
148 bool BBlockLdsAddExtraN,
149 index_t CShuffleMXdlPerWavePerShuffle,
150 index_t CShuffleNXdlPerWavePerShuffle,
151 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
152 index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
153 typename ComputeTypeA = InDataType,
154 typename ComputeTypeB = ComputeTypeA,
155 index_t MaxTransposeTransferSrcScalarPerVector = 1,
156 index_t MaxTransposeTransferDstScalarPerVector = 1>
158 : public DeviceGroupedConvBwdWeight<NDimSpatial,
159 InLayout,
160 WeiLayout,
161 OutLayout,
162 InDataType,
163 WeiDataType,
164 OutDataType,
165 InElementwiseOperation,
166 WeiElementwiseOperation,
167 OutElementwiseOperation,
168 ComputeTypeA,
169 ComputeTypeB>
170{
173 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
174 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
175
176 using ADataType = OutDataType;
177 using BDataType = InDataType;
178 using CDataType = WeiDataType;
179
180 // If NGCHW then ADataType must be equal to BDataType
184
185 using AElementwiseOperation = OutElementwiseOperation;
186 using BElementwiseOperation = InElementwiseOperation;
187 using CElementwiseOperation = WeiElementwiseOperation;
188
189 // TODO make A/B datatype different
190 using ABDataType = InDataType;
191
192 static constexpr auto I0 = Number<0>{};
193 static constexpr auto I1 = Number<1>{};
194 static constexpr auto I2 = Number<2>{};
195 static constexpr auto I3 = Number<3>{};
196 static constexpr auto I4 = Number<4>{};
197 static constexpr auto I5 = Number<5>{};
198
199 static constexpr auto K1Number = Number<K1>{};
200
201 static constexpr auto conv_to_gemm_transformer =
203 MPerBlock,
204 NPerBlock,
205 K1Number,
206 K0PerBlock,
207 ConvBackwardWeightSpecialization>{};
208
209 // Bytes per 32 lds bank: 32 * 4 bytes
210 static constexpr auto BankLength = 128;
211 static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
212
213 // M1 & M0
214 static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
215 static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
216 static constexpr auto ABlockLdsM1Padding = 4;
217
218 // N1 & N0
219 static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
220 static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
221 static constexpr auto BBlockLdsN1Padding = 4;
222
 // Overload selected when NDim == 1.
 // Builds the A/B/C grid descriptors for a dummy 1-D problem in which every dimension,
 // length, stride and convolution parameter is 1 and the batch count is 1, and returns
 // whatever conv_to_gemm_transformer produces for it.
 // NOTE(review): with all-ones inputs this looks like a helper for deducing descriptor
 // types rather than for describing a real problem - confirm against its users.
 // The arrays are sized by NDimSpatial while the initializers are written for one spatial
 // dimension, so this overload is presumably only instantiated with NDim == NDimSpatial.
223 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
224 static auto GetABCGridDesc()
225 {
226 const ck::index_t dim = 1;
227 const ck::index_t batch = 1;
228 const std::array<ck::index_t, NDimSpatial> lengths{1};
229 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
230 const std::array<ck::index_t, NDimSpatial> params{1};
 // `dim` is passed three times, `lengths` and `strides` three times each, and `params`
 // four times. The same argument positions in the Argument constructor's call take
 // Conv_N_/Conv_K_/Conv_C_, three stride arrays, and filter strides / dilations /
 // left pads / right pads.
231 return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
232 dim,
233 dim,
234 dim,
235 lengths,
236 lengths,
237 lengths,
238 strides,
239 strides,
240 strides,
241 params,
242 params,
243 params,
244 params,
245 batch);
246 }
247
 // Overload selected when NDim == 2.
 // Same as the 1-D overload but with two spatial dimensions: all lengths, strides and
 // convolution parameters are 1 and the batch count is 1.
 // NOTE(review): looks like a type-deduction helper rather than a real problem
 // description - confirm against its users.
 // The arrays are sized by NDimSpatial while the initializers are written for two spatial
 // dimensions, so this overload is presumably only instantiated with NDim == NDimSpatial.
248 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
249 static auto GetABCGridDesc()
250 {
251 const ck::index_t dim = 1;
252 const ck::index_t batch = 1;
253 const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
254 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
255 const std::array<ck::index_t, NDimSpatial> params{1, 1};
256 return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
257 dim,
258 dim,
259 dim,
260 lengths,
261 lengths,
262 lengths,
263 strides,
264 strides,
265 strides,
266 params,
267 params,
268 params,
269 params,
270 batch);
271 }
272
 // Overload selected when NDim == 3.
 // Same as the 1-D overload but with three spatial dimensions: all lengths, strides and
 // convolution parameters are 1 and the batch count is 1.
 // NOTE(review): looks like a type-deduction helper rather than a real problem
 // description - confirm against its users.
 // The arrays are sized by NDimSpatial while the initializers are written for three spatial
 // dimensions, so this overload is presumably only instantiated with NDim == NDimSpatial.
273 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
274 static auto GetABCGridDesc()
275 {
276 const ck::index_t dim = 1;
277 const ck::index_t batch = 1;
278 const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
279 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
280 const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
281 return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
282 dim,
283 dim,
284 dim,
285 lengths,
286 lengths,
287 lengths,
288 strides,
289 strides,
290 strides,
291 params,
292 params,
293 params,
294 params,
295 batch);
296 }
297
299
303
305 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
307 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
308
309 static constexpr auto conv_ngchw_to_nhwgc_transformer =
311 WeiLayout,
312 OutLayout,
313 NDimSpatial,
314 MPerBlock / ClusterLengthMPerBlock,
315 NPerBlock / ClusterLengthNPerBlock>{};
316
318
320 std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector);
322 std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferDstScalarPerVector);
323
326 .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
329 .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
332 .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
335 .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
336
344 BlockSize,
345 MPerBlock,
346 NPerBlock,
347 MPerBlock / ClusterLengthMPerBlock,
348 NPerBlock / ClusterLengthNPerBlock,
352 I1,
353 I0>;
354
355 // NPerBlock is used for the first dim which is store dimension
356 // (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
364 BlockSize,
365 MPerBlock,
366 NPerBlock,
367 MPerBlock / ClusterLengthMPerBlock,
368 NPerBlock / ClusterLengthNPerBlock,
372 I1,
373 I0>;
374
375 template <index_t NXdlPerWave_>
377 BlockSize,
378 ADataType,
379 BDataType,
380 AccDataType,
381 CDataType,
389 MPerBlock,
390 NPerBlock,
391 K0PerBlock,
392 MPerXDL,
393 NPerXDL,
394 K1,
395 MXdlPerWave,
396 NXdlPerWave_,
397 ABlockTransferThreadClusterLengths_K0_M_K1,
398 ABlockTransferThreadClusterArrangeOrder,
399 ABlockTransferSrcAccessOrder,
400 ABlockTransferSrcVectorDim,
401 ABlockTransferSrcScalarPerVector,
402 ABlockTransferDstScalarPerVector_K1,
403 false, // AThreadTransferSrcResetCoordinateAfterRun,
404 ABlockLdsAddExtraM,
408 BBlockTransferThreadClusterLengths_K0_N_K1,
409 BBlockTransferThreadClusterArrangeOrder,
410 BBlockTransferSrcAccessOrder,
411 BBlockTransferSrcVectorDim,
412 BBlockTransferSrcScalarPerVector,
413 BBlockTransferDstScalarPerVector_K1,
414 false, // BThreadTransferSrcResetCoordinateAfterRun,
415 BBlockLdsAddExtraN,
419 CShuffleMXdlPerWavePerShuffle,
420 CShuffleNXdlPerWavePerShuffle,
421 CBlockTransferScalarPerVector_NWaveNPerXdl,
422 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
423 true,
424 true,
425 1,
427 ComputeTypeA,
428 ComputeTypeB>;
431
432 // Argument
435
438
440 {
441 template <typename GridwiseGemm>
442 static int GetMaxOccupancy()
443 {
444 constexpr int dynamic_smem_size = 0;
445 int max_occupancy = 0;
446 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
447 &max_occupancy,
449 GridwiseGemm,
450 ADataType,
451 BDataType,
452 CDataType,
453 OutElementwiseOperation,
454 InElementwiseOperation,
455 WeiElementwiseOperation,
460 ComputePtrOffsetOfStridedBatch<>,
461 false>, // Both true/false give the same occupancy.
462 BlockSize,
463 dynamic_smem_size));
464 return std::max(1, max_occupancy);
465 }
467 {
468 max_occupancy_ = 1;
469 if(get_warp_size() == 64)
470 {
471 if constexpr(NXdlPerWave64 > 0)
472 {
474 }
475 }
476 else
477 {
478 if constexpr(NXdlPerWave32 > 0)
479 {
481 }
482 }
483 }
485 };
486
487 struct Argument : public BaseArgument, public ArgumentSplitK
488 {
489 Argument(const InDataType* p_in_grid,
490 WeiDataType* p_wei_grid,
491 const OutDataType* p_out_grid,
492 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
493 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
494 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
495 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
496 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
497 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
498 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
499 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
500 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
501 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
502 const ck::index_t M01,
503 const ck::index_t N01,
504 InElementwiseOperation in_element_op,
505 WeiElementwiseOperation wei_element_op,
506 OutElementwiseOperation out_element_op,
507 ck::index_t split_k)
508 : p_a_grid_{p_out_grid},
509 p_b_grid_{p_in_grid},
510 p_c_grid_{p_wei_grid},
516 M01_{M01},
517 N01_{N01},
518 a_element_op_{out_element_op},
519 b_element_op_{in_element_op},
520 c_element_op_{wei_element_op},
521 Conv_G_{b_g_n_c_wis_lengths[0]},
522 Conv_N_{b_g_n_c_wis_lengths[1]},
523 Conv_K_{e_g_k_c_xs_lengths[1]},
524 Conv_C_{b_g_n_c_wis_lengths[2]},
528 conv_filter_strides_{conv_filter_strides},
529 input_left_pads_{input_left_pads},
530 input_right_pads_{input_right_pads}
531 {
532 static ActiveWorkgroupsPerCU active_workgroups_per_cu;
533
536 e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
537 sizeof(WeiDataType);
538
539 constexpr index_t spatial_offset = 3;
540 std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
541 end(b_g_n_c_wis_lengths),
543 std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
544 end(e_g_k_c_xs_lengths),
546 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
547 end(a_g_n_k_wos_lengths),
549
550 std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
551 conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
552 a_g_n_k_wos_strides);
553 std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
554 conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
555 b_g_n_c_wis_strides);
556 std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
557 conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
558 e_g_k_c_xs_strides);
559
560#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
561 if(split_k < 0)
562 {
563 ck::index_t gemmM, gemmN;
564 std::tie(gemmM, gemmN, std::ignore) =
565 get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
566
567 const auto grid_size =
570 grid_size);
571 }
572 else
573#endif
574 {
575 k_batch_ = split_k;
576 }
577
578 const auto descs =
580 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
581 Conv_N_,
582 Conv_K_,
583 Conv_C_,
587 b_g_n_c_wis_strides_transposed,
588 e_g_k_c_xs_strides_transposed,
589 a_g_n_k_wos_strides_transposed,
590 conv_filter_strides,
591 conv_filter_dilations,
592 input_left_pads,
593 input_right_pads,
594 k_batch_);
595
598 c_grid_desc_m_n_ = descs[I2];
599
602
603 // A/B/C Batch Stride
604 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
605 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
606 compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
607
610 {
612 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
613 a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
615 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
616 a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
617
619 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
620 b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
622 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
623 b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
624
626 conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
627 e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
629 conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
630 e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
631
633 a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
634
636 b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
637
639 e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
640 }
641 }
642
644 {
647 {
648 // Align to 128B
650 sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
651 128;
652 }
653 else
654 {
655 return 0;
656 }
657 }
658
660 {
663 {
664 // Align to 128B
666 sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) *
667 128;
668 }
669 else
670 {
671 return 0;
672 }
673 }
674
676 {
679 {
680 return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize();
681 }
682 else
683 {
684 return 0;
685 }
686 }
687
693
700
702
705
708
711
712 // for computing batch offset
713 ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
714
717
718 OutElementwiseOperation a_element_op_;
719 InElementwiseOperation b_element_op_;
720 WeiElementwiseOperation c_element_op_;
721
722 // for checking IsSupportedArgument()
 // Spatial extents copied out of the length arrays by the constructor (the entries after
 // the first three of each array).
727 std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
728 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
729 std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
 // WARNING: the next three members are references, not copies. The constructor binds them
 // to the arrays it was given, so those arrays must outlive this Argument; passing
 // temporaries leaves these members dangling.
730 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
731 const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
732 const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
734 };
735
736 // Invoker
737 struct Invoker : public BaseInvoker
738 {
740
 // Debug helper: writes the lengths of the A and B grid descriptors (four dimensions each,
 // indices I0..I3) and of the C grid descriptor (two dimensions, I0..I1) to std::cout,
 // one descriptor per line. It only reads `arg`.
741 void ShowInfo(const Argument& arg)
742 {
743 std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
744 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
745 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
746 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
747 << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
748
749 std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
750 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
751 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
752 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
753 << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
754
755 std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
756 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
757 }
758
759 template <typename GridwiseGemm>
760 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
761 {
762 float avg_time = 0.f;
763
764 const ADataType* p_a_grid = arg.p_a_grid_;
765 const BDataType* p_b_grid = arg.p_b_grid_;
766 CDataType* p_e_grid = arg.p_c_grid_;
767
768 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
770 arg.c_grid_desc_m_n_);
771
774 {
775 p_e_grid =
778 sizeof(CDataType);
779 }
780
783 {
784 const index_t grid_size_a =
785 arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
787 const index_t grid_size_b =
788 arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
790
795 BDataType* p_out_b_grid = type_convert<BDataType*>(arg.p_workspace_) +
797
798 // Different data type for A and B is not supported
799 auto kernel_transpose = kernel_elementwise_dual<GridwiseInOutTranspose,
812
813 avg_time += launch_and_time_kernel(stream_config,
814 kernel_transpose,
815 dim3(grid_size_a + grid_size_b),
816 dim3(BlockSize),
817 0,
824 make_tuple(p_out_a_grid),
825 make_tuple(p_out_b_grid),
829 grid_size_a);
830 }
831
832 const index_t grid_size =
833 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
834
835 const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
836
837 const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
838
839 auto launch_kernel = [&](auto has_main_k_block_loop) {
840 constexpr bool has_main_loop = has_main_k_block_loop.value;
841
842 const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
843 GridwiseGemm,
844 ADataType,
845 BDataType,
846 CDataType,
847 OutElementwiseOperation,
848 InElementwiseOperation,
849 WeiElementwiseOperation,
854 ComputePtrOffsetOfStridedBatch<>,
855 has_main_loop>;
856
857 const auto clear_workspace = [&]() {
858 hip_check_error(hipMemsetAsync(
859 p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
860 };
861
863 stream_config,
864 clear_workspace,
865 kernel,
866 dim3(grid_size),
867 dim3(BlockSize),
868 0,
869 p_a_grid,
870 p_b_grid,
871 p_e_grid,
872 arg.a_element_op_,
873 arg.b_element_op_,
874 arg.c_element_op_,
875 arg.Conv_G_,
878 c_grid_desc_mblock_mperblock_nblock_nperblock,
881 };
882
883 if(has_main_k0_block_loop)
884 {
885 launch_kernel(integral_constant<bool, true>{});
886 }
887 else
888 {
889 launch_kernel(integral_constant<bool, false>{});
890 }
891
894 {
895 const index_t grid_size_e =
896 arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
898
899 const CDataType* p_e_in_grid = static_cast<const CDataType*>(p_e_grid);
900
901 // Different data type for A and B is not supported
903 ck::Tuple<GKYXCTransposeDescType>,
904 ck::Tuple<GKCYXTransposeDescType>,
905 ck::Tuple<const CDataType*>,
906 ck::Tuple<CDataType*>,
908 element_wise::PassThrough>;
909
910 avg_time += launch_and_time_kernel(stream_config,
911 kernel_transpose,
912 dim3(grid_size_e),
913 dim3(BlockSize),
914 0,
917 make_tuple(p_e_in_grid),
920 element_wise::PassThrough{});
921 }
922
923 return avg_time;
924 }
925
927
928 float Run(const BaseArgument* p_arg,
929 const StreamConfig& stream_config = StreamConfig{}) override
930 {
931 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
932 }
933 };
934
 // Compile-time check of this instance's template parameters.
 // Currently a placeholder: it accepts every parameter combination (see the TODO below).
935 static constexpr bool IsValidCompilationParameter()
936 {
937 // TODO: properly implement this check
938 return true;
939 }
940
941 static bool IsSupportedArgument(const Argument& arg)
942 {
943#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
944 if(arg.k_batch_ < 0)
945 {
946 return false;
947 }
948#endif
950 {
951 return false;
952 }
953 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
954 {
955 return false;
956 }
958 {
959 if(!is_tf32_supported())
960 {
961 return false;
962 }
964 {
965 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
966 {
967 std::cout << "ComputeDataType for A and B should be same while using TF32"
968 << std::endl;
969 }
970 return false;
971 }
972 }
973 if constexpr(NDimSpatial == 1)
974 {
976 {
977 return false;
978 }
979 }
980 else if constexpr(NDimSpatial == 2)
981 {
985 {
986 return false;
987 }
988 }
989 else if constexpr(NDimSpatial == 3)
990 {
994 {
995 return false;
996 }
997 }
998 else
999 {
1000 return false;
1001 }
1002
1003 if constexpr(ConvBackwardWeightSpecialization ==
1005 {
1006 // check if it's 1x1, stride=1 pad = 0 conv
1007 for(int i = 0; i < NDimSpatial; i++)
1008 {
1009 if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1010 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1011 {
1012 return false;
1013 }
1014 }
1015 }
1016
1017 // vector load A/B matrix from global memory
1018 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
1019 arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
1020 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
1021 {
1022 return false;
1023 }
1024
1025 // vector store C matrix into global memory
1026 if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
1027 {
1028 return false;
1029 }
1030
1033 {
1035 {
1036 return false;
1037 }
1038
1040 {
1041 return false;
1042 }
1043
1044 const index_t input_spatial_acum = ck::accumulate_n<index_t>(
1045 arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
1046 const index_t output_spatial_acum = ck::accumulate_n<index_t>(
1047 arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
1048
1049 if(input_spatial_acum % TransposeTransferSrcScalarPerVectorAligned != 0)
1050 {
1051 return false;
1052 }
1053
1054 if(output_spatial_acum % TransposeTransferSrcScalarPerVectorAligned != 0)
1055 {
1056 return false;
1057 }
1058
1059 if(!arg.p_workspace_)
1060 {
1061 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1062 {
1063 std::cout << "Warning: Workspace for "
1064 "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument is not "
1065 "allocated, use SetWorkSpacePointer."
1066 << std::endl;
1067 }
1068 return false;
1069 }
1070
1071 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1072 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
1073 arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
1074 {
1075 return false;
1076 }
1077 }
1078
1079 // Gridwise GEMM size
1080 if(get_warp_size() == 64)
1081 {
1082 if constexpr(NXdlPerWave64 > 0)
1083 {
1086 arg.c_grid_desc_m_n_,
1087 arg.block_2_ctile_map_);
1088 }
1089 }
1090 else
1091 {
1092 if constexpr(NXdlPerWave32 > 0)
1093 {
1096 arg.c_grid_desc_m_n_,
1097 arg.block_2_ctile_map_);
1098 }
1099 }
1100 return false;
1101 }
1102
1103 bool IsSupportedArgument(const BaseArgument* p_arg) override
1104 {
1105 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1106 }
1107
 // Builds an Argument by value from typed pointers and the convolution problem description.
 // Every parameter is forwarded to the Argument constructor in the same order; the two
 // literal 1s fill the constructor's M01 and N01 parameters, which this factory does not
 // expose.
 // The four NDimSpatial-sized arrays are taken by const reference and forwarded as such.
 // The Argument keeps references to some of them, so they must outlive the returned object.
1108 static auto
1109 MakeArgument(const InDataType* p_in_grid,
1110 WeiDataType* p_wei_grid,
1111 const OutDataType* p_out_grid,
1112 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
1113 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1114 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
1115 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1116 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
1117 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1118 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1119 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1120 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1121 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1122 InElementwiseOperation in_element_op,
1123 WeiElementwiseOperation wei_element_op,
1124 OutElementwiseOperation out_element_op,
1125 const ck::index_t split_k)
1126 {
1127 return Argument{p_in_grid,
1128 p_wei_grid,
1129 p_out_grid,
1130 b_g_n_c_wis_lengths, // input
1131 b_g_n_c_wis_strides,
1132 e_g_k_c_xs_lengths, // weight
1133 e_g_k_c_xs_strides,
1134 a_g_n_k_wos_lengths, // output
1135 a_g_n_k_wos_strides,
1136 conv_filter_strides,
1137 conv_filter_dilations,
1138 input_left_pads,
1139 input_right_pads,
 // M01 and N01
1140 1,
1141 1,
1142 in_element_op,
1143 wei_element_op,
1144 out_element_op,
1145 split_k};
1146 }
1147
 // Returns a default-constructed Invoker by value.
1148 static auto MakeInvoker() { return Invoker{}; }
1149
 // Type-erased counterpart of MakeArgument: takes untyped buffer pointers, casts them to
 // InDataType / WeiDataType / OutDataType with static_cast (no runtime check is possible,
 // so the caller must pass buffers of the matching element types), and returns a
 // heap-allocated Argument owned by a std::unique_ptr<BaseArgument>.
 // As in MakeArgument, the two literal 1s fill the constructor's M01 and N01 parameters.
 // The Argument keeps references to some of the NDimSpatial-sized arrays, so those arrays
 // must outlive the returned object.
1150 std::unique_ptr<BaseArgument>
1151 MakeArgumentPointer(const void* p_in_grid,
1152 void* p_wei_grid,
1153 const void* p_out_grid,
1154 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
1155 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1156 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
1157 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1158 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
1159 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1160 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1161 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1162 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1163 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1164 InElementwiseOperation in_element_op,
1165 WeiElementwiseOperation wei_element_op,
1166 OutElementwiseOperation out_element_op,
1167 const ck::index_t split_k) override
1168 {
1169 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
1170 static_cast<WeiDataType*>(p_wei_grid),
1171 static_cast<const OutDataType*>(p_out_grid),
1172 b_g_n_c_wis_lengths, // input
1173 b_g_n_c_wis_strides,
1174 e_g_k_c_xs_lengths, // weight
1175 e_g_k_c_xs_strides,
1176 a_g_n_k_wos_lengths, // output
1177 a_g_n_k_wos_strides,
1178 conv_filter_strides,
1179 conv_filter_dilations,
1180 input_left_pads,
1181 input_right_pads,
 // M01 and N01
1182 1,
1183 1,
1184 in_element_op,
1185 wei_element_op,
1186 out_element_op,
1187 split_k);
1188 }
1189
 // Type-erased counterpart of MakeInvoker: returns a heap-allocated Invoker owned by a
 // std::unique_ptr<BaseInvoker>. The new object is constructed from a temporary
 // default-constructed Invoker.
1190 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1191 {
1192 return std::make_unique<Invoker>(Invoker{});
1193 }
1194
1195 std::string GetTypeString() const override
1196 {
1197 auto str = std::stringstream();
1198
1199 // clang-format off
1200 str << "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
1201 << "<"
1202 << BlockSize << ", "
1203 << MPerBlock << ", "
1204 << NPerBlock << ", "
1205 << K0PerBlock << ", "
1206 << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
1207 << K1 << ", "
1208 << MXdlPerWave << ", "
1209 << NXdlPerWave << ", "
1210 << ABlockTransferSrcScalarPerVector << ", "
1211 << ABlockTransferDstScalarPerVector_K1 << ", "
1212 << BBlockTransferSrcScalarPerVector << ", "
1213 << BBlockTransferDstScalarPerVector_K1 << ", "
1214 << CShuffleMXdlPerWavePerShuffle << ", "
1215 << CShuffleNXdlPerWavePerShuffle << ", "
1216 << CBlockTransferScalarPerVector_NWaveNPerXdl;
1217
1220 str << ", TransposeTransferSrcScalarPerVectorAligned: "
1222 << "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned;
1223 }
1224
1225
1226 str << ">";
1227 // clang-format on
1228
1229 return str.str();
1230 }
1231
1232#ifdef CK_EXPERIMENTAL_BUILDER
 // Returns the string produced by ck_tile::reflect::instance_string for this device
 // operation type (DeviceOp).
 // Compilation fails with the message below when no instance_traits specialization exists
 // for the current template parameters, so a missing specialization is caught at build
 // time rather than at run time.
1233 std::string GetInstanceString() const override
1234 {
1235 static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
1236 "Specialization of instance_traits not found. Please check that a "
1237 "specialization exists in file "
1238 "ck_tile/builder/reflect/"
1239 "instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp "
1240 "for the given template parameters.");
1241 return ck_tile::reflect::instance_string<DeviceOp>();
1242 }
1243#endif
1244
 // Reports how many bytes of workspace the given argument needs.
 //
 // p_arg   - argument object; must point to this operation's Argument.
 // returns - the value of Argument::GetWorkspaceSizeBytes() for that argument.
 // throws  - std::runtime_error when p_arg is null or is not an Argument of this operation
 //           (the dynamic_cast below yields nullptr in both cases).
1245 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
1246 {
1247 auto arg = dynamic_cast<const Argument*>(p_arg);
1248 if(arg)
1249 {
1250 return arg->GetWorkspaceSizeBytes();
1251 }
1252 else
1253 throw std::runtime_error(
1254 "The argument pointer is not an object of "
1255 "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument structure!");
1256 }
1257
1259 void* p_workspace,
1260 const StreamConfig& = StreamConfig{}) const override
1261 {
1262 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
1263 if(p_arg_)
1264 {
1265 p_arg_->p_workspace_ = p_workspace;
1266 }
1267 else
1268 throw std::runtime_error(
1269 "The argument pointer is not an object of "
1270 "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument structure!");
1271 }
1272};
1273
1274} // namespace device
1275} // namespace tensor_operation
1276} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition split_k_utils.hpp:55
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
constexpr bool is_GNWC_GKXC_GNWK()
Definition device_grouped_conv_utils.hpp:23
__global__ void kernel_batched_gemm_xdlops_bwd_weight(const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const index_t batch_count, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp:50
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition split_k_utils.hpp:84
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition device_grouped_conv_utils.hpp:48
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
int64_t long_index_t
Definition ck.hpp:300
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_xdlops_bwd_weight.hpp:254
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:24
Definition transform_conv_ngchw_to_nhwgc.hpp:31
index_t k_batch_
Definition split_k_arg.hpp:12
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:440
int max_occupancy_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:484
static int GetMaxOccupancy()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:442
ActiveWorkgroupsPerCU()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:466
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:488
WeiElementwiseOperation c_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:720
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:713
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:731
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, const ck::index_t M01, const ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:489
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:698
NHWGCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:707
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:718
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:699
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:643
std::array< ck::index_t, NDimSpatial > output_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:729
std::array< ck::index_t, NDimSpatial > input_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:727
InElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:719
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:723
Block2CTileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:701
NGCHWTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:706
long_index_t c_space_size_bytes
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:733
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:688
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:694
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:730
index_t N01_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:716
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:725
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:706
CDataType * p_c_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:696
index_t M01_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:715
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:707
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:703
GKCYXTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:710
std::array< ck::index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:728
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:704
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:732
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:659
Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_e_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:704
GKYXCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:709
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:697
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:695
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:726
const index_t Conv_N_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:724
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:675
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:738
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:739
void ShowInfo(const Argument &arg)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:741
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:928
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:760
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:170
static constexpr auto BBlockLdsN0PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:220
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1258
GridwiseElementwise< Tuple< GKYXCTransposeDescType >, Tuple< GKCYXTransposeDescType >, Tuple< const CDataType * >, Tuple< CDataType * >, Block2TileMapTranspose, element_wise::PassThrough, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CBlockTransferScalarPerVector_NWaveNPerXdl >, Sequence< 1 >, I1, I0 > GridwiseElementwiseWeightTranspose
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:357
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:185
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:298
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:186
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:194
static constexpr auto ABlockLdsM1PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:214
static constexpr index_t TransposeTransferSrcScalarPerVectorAligned
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:319
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:327
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:174
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1103
static constexpr index_t TransposeTransferDstScalarPerVectorAligned
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:321
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMapTranspose
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:317
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:195
InDataType BDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:177
static constexpr auto ElePerBank
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:211
static constexpr index_t ClusterLengthMPerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:304
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:197
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:430
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:333
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:192
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, BDataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true, 1, PipelineVersion::v1, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:376
decltype(GridwiseGemm64::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:433
decltype(GridwiseGemm64::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)) Block2CTileMap
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:436
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:935
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:302
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:324
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapTranspose, element_wise::PassThrough, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< TransposeTransferSrcScalarPerVectorAligned >, Sequence< TransposeTransferDstScalarPerVectorAligned >, I1, I0 > GridwiseInOutTranspose
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:337
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:429
static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:309
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:193
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:196
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:330
static constexpr auto BBlockLdsN1Padding
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:221
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1245
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1148
static constexpr auto BBlockLdsN1PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:219
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:941
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:301
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1190
static constexpr auto ABlockLdsM1Padding
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:216
WeiDataType CDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:178
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1195
static constexpr auto BankLength
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:210
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1151
static constexpr auto K1Number
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:199
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:190
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:306
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:300
static constexpr auto ABlockLdsM0PerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:215
WeiElementwiseOperation CElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:187
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:173
DeviceGroupedConvBwdWeight_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:171
static constexpr auto conv_to_gemm_transformer
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:201
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:1109
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:176
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle.hpp:224
Definition device_grouped_conv_bwd_weight.hpp:29
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129