26template <
typename GridwiseGemm,
27 bool HasMainKBlockLoop,
32#if CK_USE_LAUNCH_BOUNDS
38#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
39 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
41 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
43 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
45 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
46 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
47 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
48 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
49 karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
58template <
typename GridwiseGemm,
59 bool HasMainKBlockLoop,
64#if CK_USE_LAUNCH_BOUNDS
70#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
71 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
75 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
76 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
78 auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
80 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
81 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
82 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
83 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
84 karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
94template <
typename ALayout,
100 typename CShuffleDataType,
102 typename AElementwiseOperation,
103 typename BElementwiseOperation,
104 typename CElementwiseOperation,
118 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
119 typename ABlockTransferThreadClusterArrangeOrder,
120 typename ABlockTransferSrcAccessOrder,
121 index_t ABlockTransferSrcVectorDim,
122 index_t ABlockTransferSrcScalarPerVector,
123 index_t ABlockTransferDstScalarPerVector_AK1,
124 bool AThreadTransferSrcResetCoordinateAfterRun,
126 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
127 typename BBlockTransferThreadClusterArrangeOrder,
128 typename BBlockTransferSrcAccessOrder,
129 index_t BBlockTransferSrcVectorDim,
130 index_t BBlockTransferSrcScalarPerVector,
131 index_t BBlockTransferDstScalarPerVector_BK1,
132 bool BThreadTransferSrcResetCoordinateAfterRun,
134 index_t CShuffleMXdlPerWavePerShuffle,
135 index_t CShuffleNXdlPerWavePerShuffle,
136 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
140 typename ComputeTypeA = CDataType,
141 typename ComputeTypeB = ComputeTypeA,
142 bool PermuteA =
false,
143 bool PermuteB =
false>
175 MfmaSelector<ComputeTypeA,
220 auto K_t = K_Batch * KPerBlock;
221 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
226 auto K_t = K_Batch * KPerBlock;
227 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
232 auto K_t = K_Batch * KPerBlock;
233 return (K + K_t - 1) / K_t * KPerBlock;
239 auto K_t = K_Batch * KReadVec;
240 return (K + K_t - 1) / K_t * KReadVec;
253 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl,
typename TileDesc_K0_MN_K1>
271 const auto a_grid_desc_mraw_kraw = [&]() {
284 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
285 GemmSpec == GemmSpecialization::MNKPadding)
288 const auto a_grid_desc_m_k =
302 return a_grid_desc_ak0_m_ak1;
304 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
305 GemmSpec == GemmSpecialization::MNPadding)
309 a_grid_desc_mraw_kraw,
315 return a_grid_desc_ak0_m_ak1;
317 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
318 GemmSpec == GemmSpecialization::NKPadding)
322 a_grid_desc_mraw_kraw,
334 return a_grid_desc_ak0_m_ak1;
340 a_grid_desc_mraw_kraw,
346 return a_grid_desc_ak0_m_ak1;
353 const auto b_grid_desc_nraw_kraw = [&]() {
367 GemmSpec != GemmSpecialization::Default),
368 "pk_i4_t does not support padding");
370 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
371 GemmSpec == GemmSpecialization::MNKPadding)
374 const auto b_grid_desc_n_k =
388 return b_grid_desc_bk0_n_bk1;
390 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
391 GemmSpec == GemmSpecialization::MNPadding)
395 b_grid_desc_nraw_kraw,
401 return b_grid_desc_bk0_n_bk1;
403 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
404 GemmSpec == GemmSpecialization::MKPadding)
408 b_grid_desc_nraw_kraw,
420 return b_grid_desc_bk0_n_bk1;
424 if constexpr(!PermuteB)
428 b_grid_desc_nraw_kraw,
434 return b_grid_desc_bk0_n_bk1;
439 constexpr index_t BK01 = KPerBlock / BK1Value;
441 const index_t BK0_ = StrideB / BK1Value;
442 const index_t BK00 = BK0_ / BK01;
444 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
448 b_grid_desc_bk00_n_bk01_bk1_permute,
455 return b_grid_desc_bk0_n_bk1_permute;
460 template <
typename ABlockDesc_AK0_M_AK1>
461 __host__ __device__
static constexpr auto
464 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
469 template <
typename BBlockDesc_BK0_N_BK1>
470 __host__ __device__
static constexpr auto
473 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
478 __host__ __device__
static auto
481 const auto c_grid_desc_mraw_nraw = [&]() {
501 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
502 GemmSpec == GemmSpecialization::MNKPadding)
511 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
512 GemmSpec == GemmSpecialization::MKPadding)
516 c_grid_desc_mraw_nraw,
521 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
522 GemmSpec == GemmSpecialization::NKPadding)
526 c_grid_desc_mraw_nraw,
534 return c_grid_desc_mraw_nraw;
570 std::cout <<
"problem {" <<
"M:" <<
M <<
", " <<
"N:" <<
N <<
", " <<
"K:" <<
K <<
", "
574 <<
", " <<
"AK0:" <<
AK0 <<
", " <<
"BK0:" <<
BK0 <<
", "
575 <<
"MBlock: " <<
MBlock <<
", " <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
600 const BDataType* p_b_grid_,
601 CDataType* p_c_grid_,
614 bool is_reduce_ =
false)
615 :
Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
668 if constexpr(!PermuteB)
674 const int k0_offset = karg.
KRead * karg.
N;
689 if(k_id < (karg.
KBatch - 1))
716 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
717 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
718 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
731 constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
746 a_lds_block_desc_permuted,
754 a_lds_block_desc_ak0_mldslayer_m_ak1,
762 return a_lds_block_desc_ak0_m_ak1;
769 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
770 constexpr auto M1 = MPerBlock / M0;
772 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
773 constexpr auto K0PerThreadWrite =
AK0Number / KThreadWrite;
774 constexpr auto KThreadRead = WaveSize / MPerXdl;
775 constexpr auto K0PerThreadRead =
AK0Number / KThreadRead;
777 constexpr auto kfold = (
AK1Number * M0 *
sizeof(ADataType) > 128)
779 : 128 / (
AK1Number * M0 *
sizeof(ADataType));
780 constexpr auto KThreadReadPerm =
781 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
782 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
786 constexpr auto mpair = (
AK1Number * MPerXdl *
sizeof(ADataType) > 128)
788 : ((128 / (
AK1Number * MPerXdl *
sizeof(ADataType))) > M0
790 : 128 / (
AK1Number * MPerXdl *
sizeof(ADataType)));
796 Number<kfold * M0 / mpair>{},
815 a_lds_block_desc_permuted,
837 a_lds_block_desc_unmerged,
840 Number<KThreadWrite / kfold / KThreadReadPerm>{},
849 return a_lds_block_desc_ak0_m_ak1;
855 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
856 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
857 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
869 constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
884 b_lds_block_desc_permuted,
892 b_lds_block_desc_bk0_nldslayer_n_bk1,
900 return b_lds_block_desc_bk0_n_bk1;
904 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
905 constexpr auto N1 = NPerBlock / N0;
907 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
908 constexpr auto K0PerThreadWrite =
BK0Number / KThreadWrite;
909 constexpr auto KThreadRead = WaveSize / NPerXdl;
910 constexpr auto K0PerThreadRead =
BK0Number / KThreadRead;
912 constexpr auto kfold = (
BK1Number * N0 *
sizeof(BDataType) > 128)
914 : 128 / (
BK1Number * N0 *
sizeof(BDataType));
915 constexpr auto KThreadReadPerm =
916 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
917 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
921 constexpr auto npair = (
BK1Number * NPerXdl *
sizeof(BDataType) > 128)
923 : ((128 / (
BK1Number * NPerXdl *
sizeof(BDataType))) > N0
925 : 128 / (
BK1Number * NPerXdl *
sizeof(BDataType)));
931 Number<kfold * N0 / npair>{},
950 b_lds_block_desc_permuted,
972 b_lds_block_desc_unmerged,
975 Number<KThreadWrite / kfold / KThreadReadPerm>{},
984 return b_lds_block_desc_bk0_n_bk1;
990 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
991 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
993 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1000 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1018 ABlockTransferSrcScalarPerVector,
1019 BBlockTransferSrcScalarPerVector,
1039 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1042 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1045 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1048 constexpr auto c_block_size =
1049 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1052 b_block_space_size_aligned *
sizeof(BDataType) /
BPackedSize),
1053 c_block_size *
sizeof(CShuffleDataType));
1061 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1062 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1063 "Invalid tuning param!");
1071 if(!(karg.M % MPerBlock == 0))
1075 std::cout <<
"Arg M value is not a multiple of MPerBlock! M: " << karg.M <<
" "
1076 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1089 if(!(karg.N % NPerBlock == 0))
1093 std::cout <<
"Arg N value is not a multiple of NPerBlock! N: " << karg.N <<
" "
1094 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1107 auto K_t = karg.KBatch * KPerBlock;
1108 if(!(karg.K % K_t == 0))
1112 std::cout <<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1113 << karg.K <<
" " << __FILE__ <<
":" << __LINE__
1114 <<
", in function: " << __func__ << std::endl;
1122 auto K_t = karg.KBatch * KReadVec;
1124 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1132 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1136 std::cout <<
"Arg K (" << karg.K
1137 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1138 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1139 << __LINE__ <<
", in function: " << __func__ << std::endl;
1146 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1150 std::cout <<
"Arg M (" << karg.M
1151 <<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1152 << ABlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1153 << __LINE__ <<
", in function: " << __func__ << std::endl;
1161 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1165 std::cout <<
"Arg N (" << karg.N
1166 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1167 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1168 << __LINE__ <<
", in function: " << __func__ << std::endl;
1175 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1179 std::cout <<
"Arg K (" << karg.K
1180 <<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1181 << BBlockTransferSrcScalarPerVector <<
" )! " << __FILE__ <<
":"
1182 << __LINE__ <<
", in function: " << __func__ << std::endl;
1190 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1194 std::cout <<
"Arg N (" << karg.N
1195 <<
") value is not a multiple of "
1196 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1197 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1198 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1206 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1210 std::cout <<
"Arg M (" << karg.M
1211 <<
") value is not a multiple of "
1212 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1213 << CShuffleBlockTransferScalarPerVector_NPerBlock <<
" )! "
1214 << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__
1226 if(!karg.IsReduceAdd())
1230 std::cout <<
" KBatch: " << karg.KBatch <<
" > 1 is not support yet" << __FILE__
1231 <<
":" << __LINE__ <<
", in function: " << __func__ << std::endl;
1241 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1245 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1257 const index_t num_loop = K / KPerBlock;
1259 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1264 const index_t num_loop = K / KPerBlock;
1266 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1269 template <
typename CGr
idDesc>
1271 const CGridDesc& c_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
1280 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1288 template <
typename AGridDesc_AK0_M_K1,
1289 typename BGridDesc_BK0_N_K1,
1290 typename BScaleGridDesc_BN_AK,
1291 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1292 bool HasMainKBlockLoop,
1295 __device__
static void Run(
const ADataType* p_a_grid,
1296 const BDataType* p_b_grid,
1297 CDataType* p_c_grid,
1300 const Problem& problem,
1301 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1302 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1303 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1304 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1305 c_grid_desc_mblock_mperblock_nblock_nperblock)
1308 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1310 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1312 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1316 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1318 const AElementwiseOperation a_element_op{};
1319 const BElementwiseOperation b_element_op{};
1320 const CElementwiseOperation c_element_op{};
1323 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1325 const auto block_work_idx =
1328 if(!block_2_ctile_map.ValidCTileIndex(
1330 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1331 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1336 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1337 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1340 const index_t m_block_data_idx_on_grid =
1341 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1343 const index_t n_block_data_idx_on_grid =
1344 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1356 auto a_blockwise_copy =
1358 AElementwiseOperation,
1362 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1363 ABlockTransferThreadClusterArrangeOrder,
1366 decltype(a_grid_desc_ak0_m_ak1),
1367 decltype(a_block_desc_ak0_m_ak1),
1368 ABlockTransferSrcAccessOrder,
1370 ABlockTransferSrcVectorDim,
1372 ABlockTransferSrcScalarPerVector,
1373 ABlockTransferDstScalarPerVector_AK1,
1376 AThreadTransferSrcResetCoordinateAfterRun,
1378 BlockwiseGemmPipe::GlobalBufferNum>(
1379 a_grid_desc_ak0_m_ak1,
1382 a_block_desc_ak0_m_ak1,
1387 auto b_blockwise_copy =
1389 BElementwiseOperation,
1393 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1394 BBlockTransferThreadClusterArrangeOrder,
1397 decltype(b_grid_desc_bk0_n_bk1),
1398 decltype(b_block_desc_bk0_n_bk1),
1399 BBlockTransferSrcAccessOrder,
1401 BBlockTransferSrcVectorDim,
1403 BBlockTransferSrcScalarPerVector,
1404 BBlockTransferDstScalarPerVector_BK1,
1407 BThreadTransferSrcResetCoordinateAfterRun,
1409 BlockwiseGemmPipe::GlobalBufferNum>(
1410 b_grid_desc_bk0_n_bk1,
1413 b_block_desc_bk0_n_bk1,
1419 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1423 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1426 reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) + a_block_space_size_aligned *
1429 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1435 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1437 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1439 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1440 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1445 static constexpr auto mfma =
1448 static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
1449 static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
1450 static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
1452 static constexpr auto ScaleSliceSizeN = NXdlPerWave;
1453 static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
1454 static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK;
1459 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1460 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1461 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1462#if defined(__gfx11__)
1471 auto b_scale_thread_copy =
1474 decltype(b_scale_grid_desc_bn_ak),
1475 decltype(b_scale_thread_desc),
1482 b_scale_grid_desc_bn_ak,
1484 b_thread_offset_k / ScaleBlockK));
1486 constexpr auto b_scale_thread_slice_copy_step =
1491 const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock;
1494 a_grid_desc_ak0_m_ak1,
1495 a_block_desc_ak0_m_ak1,
1499 a_block_slice_copy_step,
1500 b_grid_desc_bk0_n_bk1,
1501 b_block_desc_bk0_n_bk1,
1505 b_block_slice_copy_step,
1507 b_scale_grid_desc_bn_ak,
1508 b_scale_thread_desc,
1509 b_scale_thread_copy,
1511 b_scale_thread_slice_copy_step,
1512 num_k_block_main_loop,
1513 num_k_block_per_scale);
1517 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1518 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1521 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1522 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1525 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1526 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1530 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1531 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1533 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1534 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1535 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1536 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1537 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1538 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1539 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1540 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1542 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1546 static_cast<CShuffleDataType*
>(p_shared),
1547 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1550 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1570 const auto c_thread_mtx_on_block =
1571 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1573 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1574 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1576 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1582 const auto m_thread_data_on_block_idx =
1583 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1586 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1592 const auto n_thread_data_on_block_idx =
1593 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1597 auto c_thread_copy_vgpr_to_lds =
1600 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1601 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1603 Sequence<CShuffleMXdlPerWavePerShuffle,
1604 CShuffleNXdlPerWavePerShuffle,
1617 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1620 m_thread_data_on_block_idx[
I1],
1621 n_thread_data_on_block_idx[
I1],
1622 m_thread_data_on_block_idx[
I2],
1623 m_thread_data_on_block_idx[
I3],
1624 m_thread_data_on_block_idx[
I4],
1625 n_thread_data_on_block_idx[
I2]),
1631 CElementwiseOperation,
1632 CGlobalMemoryDataOperation,
1634 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1636 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1637 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1641 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1642 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1645 CShuffleBlockTransferScalarPerVector_NPerBlock,
1648 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1650 c_grid_desc_mblock_mperblock_nblock_nperblock,
1655 constexpr auto sfc_c_vgpr =
1658 Sequence<CShuffleMXdlPerWavePerShuffle,
1659 CShuffleNXdlPerWavePerShuffle,
1668 constexpr auto sfc_c_global =
1672 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1674 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1676 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1678 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
1685 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1686 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1688 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1689 c_shuffle_block_buf);
1695 c_shuffle_block_copy_lds_to_global.Run(
1696 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1697 c_shuffle_block_buf,
1698 c_grid_desc_mblock_mperblock_nblock_nperblock,
1701 if constexpr(access_id < num_access - 1)
1703 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1706 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1707 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1713 template <
bool HasMainKBlockLoop,
1716 __device__
static void Run(
const ADataType* p_a_grid,
1717 const BDataType* p_b_grid,
1718 CDataType* p_c_grid,
1721 const Problem& problem)
1724 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1726 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1728 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1729 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1731 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1739 Run<
decltype(a_grid_desc_ak0_m_ak1),
1740 decltype(b_grid_desc_bk0_n_bk1),
1741 decltype(b_scale_grid_desc_bn_ak),
1742 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1744 CGlobalMemoryDataOperation,
1751 a_grid_desc_ak0_m_ak1,
1752 b_grid_desc_bk0_n_bk1,
1753 b_scale_grid_desc_bn_ak,
1754 c_grid_desc_mblock_mperblock_nblock_nperblock);
1757 template <
typename AGridDesc_AK0_M_K1,
1758 typename BGridDesc_BK0_N_K1,
1759 typename BScaleGridDesc_BN_AK,
1760 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1761 bool HasMainKBlockLoop,
1764 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
1765 const BDataType* p_b_grid,
1766 CDataType* p_c_grid,
1770 const Problem& problem,
1771 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1772 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1773 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1774 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1775 c_grid_desc_mblock_mperblock_nblock_nperblock)
1778 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1780 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1782 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1786 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1788 const AElementwiseOperation a_element_op{};
1789 const BElementwiseOperation b_element_op{};
1790 const CElementwiseOperation c_element_op{};
1793 const auto block_2_ctile_map =
Block2CTileMap{problem.M, problem.N, 4};
1795 const auto block_work_idx =
1798 if(!block_2_ctile_map.ValidCTileIndex(
1800 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1801 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1806 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1807 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1810 const index_t m_block_data_idx_on_grid =
1811 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1813 const index_t n_block_data_idx_on_grid =
1814 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1826 auto a_blockwise_copy =
1828 AElementwiseOperation,
1832 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1833 ABlockTransferThreadClusterArrangeOrder,
1836 decltype(a_grid_desc_ak0_m_ak1),
1837 decltype(a_block_desc_ak0_m_ak1),
1838 ABlockTransferSrcAccessOrder,
1840 ABlockTransferSrcVectorDim,
1842 ABlockTransferSrcScalarPerVector,
1843 ABlockTransferDstScalarPerVector_AK1,
1846 AThreadTransferSrcResetCoordinateAfterRun,
1848 BlockwiseGemmPipe::GlobalBufferNum>(
1849 a_grid_desc_ak0_m_ak1,
1852 a_block_desc_ak0_m_ak1,
1857 auto b_blockwise_copy =
1859 BElementwiseOperation,
1863 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1864 BBlockTransferThreadClusterArrangeOrder,
1867 decltype(b_grid_desc_bk0_n_bk1),
1868 decltype(b_block_desc_bk0_n_bk1),
1869 BBlockTransferSrcAccessOrder,
1871 BBlockTransferSrcVectorDim,
1873 BBlockTransferSrcScalarPerVector,
1874 BBlockTransferDstScalarPerVector_BK1,
1877 BThreadTransferSrcResetCoordinateAfterRun,
1879 BlockwiseGemmPipe::GlobalBufferNum>(
1880 b_grid_desc_bk0_n_bk1,
1883 b_block_desc_bk0_n_bk1,
1889 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1892 static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1896 a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
1897 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1900 static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1904 a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
1905 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1907 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
1908 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
1914 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1916 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1918 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1919 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1923 static constexpr auto mfma =
1926 static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
1927 static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
1928 static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
1930 const index_t ScaleSliceSizeN = NXdlPerWave;
1931 static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
1932 static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK;
1937 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1938 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1939 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1940#if defined(__gfx11__)
1950 auto b_scale_thread_copy =
1953 decltype(b_scale_grid_desc_bn_ak),
1954 decltype(b_scale_thread_desc),
1961 b_scale_grid_desc_bn_ak,
1963 b_thread_offset_k / ScaleBlockK));
1965 constexpr auto b_scale_thread_slice_copy_step =
1970 const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock;
1973 a_grid_desc_ak0_m_ak1,
1974 a_block_desc_ak0_m_ak1,
1978 a_block_slice_copy_step,
1979 b_grid_desc_bk0_n_bk1,
1980 b_block_desc_bk0_n_bk1,
1984 b_block_slice_copy_step,
1987 b_scale_grid_desc_bn_ak,
1988 b_scale_thread_desc,
1989 b_scale_thread_copy,
1991 b_scale_thread_slice_copy_step,
1993 num_k_block_main_loop,
1994 num_k_block_per_scale);
1998 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1999 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2002 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2003 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2006 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2007 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2011 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2012 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2014 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
2015 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
2016 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
2017 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
2018 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
2019 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
2020 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
2021 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
2023 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2027 static_cast<CShuffleDataType*
>(p_shared_0),
2028 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2031 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2051 const auto c_thread_mtx_on_block =
2052 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
2054 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
2055 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
2057 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2063 const auto m_thread_data_on_block_idx =
2064 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2067 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2073 const auto n_thread_data_on_block_idx =
2074 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2078 auto c_thread_copy_vgpr_to_lds =
2081 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2082 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2084 Sequence<CShuffleMXdlPerWavePerShuffle,
2085 CShuffleNXdlPerWavePerShuffle,
2098 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2101 m_thread_data_on_block_idx[
I1],
2102 n_thread_data_on_block_idx[
I1],
2103 m_thread_data_on_block_idx[
I2],
2104 m_thread_data_on_block_idx[
I3],
2105 m_thread_data_on_block_idx[
I4],
2106 n_thread_data_on_block_idx[
I2]),
2112 CElementwiseOperation,
2113 CGlobalMemoryDataOperation,
2115 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2117 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2118 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2122 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2123 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2126 CShuffleBlockTransferScalarPerVector_NPerBlock,
2129 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2131 c_grid_desc_mblock_mperblock_nblock_nperblock,
2136 constexpr auto sfc_c_vgpr =
2139 Sequence<CShuffleMXdlPerWavePerShuffle,
2140 CShuffleNXdlPerWavePerShuffle,
2149 constexpr auto sfc_c_global =
2153 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2155 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2157 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2159 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
2166 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2167 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2169 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2170 c_shuffle_block_buf);
2176 c_shuffle_block_copy_lds_to_global.Run(
2177 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2178 c_shuffle_block_buf,
2179 c_grid_desc_mblock_mperblock_nblock_nperblock,
2182 if constexpr(access_id < num_access - 1)
2184 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2187 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2188 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2194 template <
bool HasMainKBlockLoop,
2197 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
2198 const BDataType* p_b_grid,
2199 CDataType* p_c_grid,
2203 const Problem& problem)
2206 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2208 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2210 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2212 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2214 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2221 Run_2Lds<
decltype(a_grid_desc_ak0_m_ak1),
2222 decltype(b_grid_desc_bk0_n_bk1),
2223 decltype(b_scale_grid_desc_bn_ak),
2224 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2226 CGlobalMemoryDataOperation,
2234 a_grid_desc_ak0_m_ak1,
2235 b_grid_desc_bk0_n_bk1,
2236 b_scale_grid_desc_bn_ak,
2237 c_grid_desc_mblock_mperblock_nblock_nperblock);
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:716
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:643
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:759
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:760
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:627
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:642
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:632
const BScaleType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:641
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:758
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:599
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:644
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:761
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:641
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:695
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:702
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:700
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:541
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:697
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:642
CElementwiseOperation c_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:711
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:706
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:708
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:701
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:696
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:698
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:704
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:699
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:707
BElementwiseOperation b_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:710
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:705
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:703
AElementwiseOperation a_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:709
index_t StrideScaleB
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:584
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:568
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:765
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:814
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:651
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:815
index_t scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:710
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:816
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBGridDescriptor_BK0_N_BK1 __host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:350
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKRead static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:236
static constexpr auto is_scale_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:273
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeGemmMmaTileDescriptor __host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:254
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKPadded static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:213
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateMPadded static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:203
static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:261
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::APackedSize static constexpr index_t APackedSize
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:285
static constexpr bool is_single_rate_mfma
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:264
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetSharedMemoryNumberOfByte static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1029
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BlockwiseGemmPipe remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, BlkGemmPipeSched, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1112
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1 static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:714
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::ThisThreadBlock ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:283
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateAK0Padded static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:218
static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:250
static constexpr index_t KPack
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:274
static constexpr auto lcm_AK1_BK1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:263
static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:255
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKBlockLoopTailNum static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1262
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateHasMainKBlockLoop static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1255
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run_2Lds static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BScaleType *p_b_scale_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1764
static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:253
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1295
static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:260
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:988
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::Run_2Lds static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1853
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:198
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateMBlock static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:243
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAMmaTileDescriptor_M0_M1_M2_K __host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:462
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1270
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateNPadded static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:208
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1 static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:853
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateBK0Padded static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:224
static constexpr index_t BPackedSize
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:292
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1437
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run_2Lds static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BScaleType *p_b_scale_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:2197
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeBMmaTileDescriptor_N0_N1_N2_K __host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:471
static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:254
static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:249
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::I0 static constexpr auto I0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:248
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::Run static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:1716
static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:251
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1202
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::Block2CTileMap BlockToCTileMap_Grouped_M00_N0_M01Adapt< 8, MPerBlock, NPerBlock > Block2CTileMap
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1428
static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:252
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateKPadded static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:230
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >< math::max(NXdlPerWave64, 1)>::BScaleType ck::half_t BScaleType
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:146
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeAGridDescriptor_AK0_M_AK1 __host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:268
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CalculateNBlock static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:248
static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:259
static constexpr auto AK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:258
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::MakeCGridDescriptor_M_N __host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:479
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr index_t GetKPerXdlops()
Definition xdlops_gemm.hpp:1804
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129