21template <
typename GridwiseGemm,
25 typename AElementwiseOperation,
26 typename BElementwiseOperation,
27 typename CDEElementwiseOperation,
28 typename AGridDesc_K0_M0_M1_K1,
29 typename BGridDesc_K0_N0_N1_K1,
30 typename DsGridDesc_M0_M10_M11_N0_N10_N11,
31 typename CGridDesc_M0_M10_M11_N0_N10_N11,
32 typename Block2CTileMap,
33 bool HasMainKBlockLoop,
34 bool HasDoubleTailKBlockLoop>
36#if CK_USE_LAUNCH_BOUNDS
40 const ABDataType* __restrict__ p_a_grid,
41 const ABDataType* __restrict__ p_b_grid,
43 EDataType* __restrict__ p_e_grid,
44 const AElementwiseOperation a_element_op,
45 const BElementwiseOperation b_element_op,
46 const CDEElementwiseOperation cde_element_op,
47 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
48 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
49 const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
50 const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
51 const Block2CTileMap block_2_ctile_map)
53#if(defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) || defined(__gfx11__) || \
56 constexpr index_t shared_block_size =
57 GridwiseGemm::GetSharedMemoryNumberOfByte() /
sizeof(ABDataType);
59 __shared__ ABDataType p_shared[shared_block_size];
61 GridwiseGemm::Run(p_a_grid,
69 a_grid_desc_k0_m0_m1_k1,
70 b_grid_desc_k0_n0_n1_k1,
71 ds_grid_desc_m0_m10_m11_n0_n10_n11,
72 e_grid_desc_m0_m10_m11_n0_n10_n11,
84 ignore = a_grid_desc_k0_m0_m1_k1;
85 ignore = b_grid_desc_k0_n0_n1_k1;
86 ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
87 ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
88 ignore = block_2_ctile_map;
97template <
typename ALayout,
103 typename AccDataType,
106 typename AElementwiseOperation,
107 typename BElementwiseOperation,
108 typename CDEElementwiseOperation,
118 typename M1N1ThreadClusterM1Xs,
119 typename M1N1ThreadClusterN1Xs,
120 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
121 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
122 typename ABlockTransferThreadClusterArrangeOrder,
123 typename ABlockTransferSrcAccessOrder,
124 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
125 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
126 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
127 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
128 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
129 typename BBlockTransferThreadClusterArrangeOrder,
130 typename BBlockTransferSrcAccessOrder,
131 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
132 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
133 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
134 typename CThreadTransferSrcDstAccessOrder,
135 index_t CThreadTransferSrcDstVectorDim,
136 index_t CThreadTransferDstScalarPerVector,
149 AElementwiseOperation,
150 BElementwiseOperation,
151 CDEElementwiseOperation>
172 const auto a_grid_desc_m_k = [&]() {
185 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
211 const auto b_grid_desc_k_n = [&]() {
224 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
244 template <
typename ELay>
247 const auto c_grid_desc_m_n = [&]() {
260 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
261 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
281 const std::array<index_t, NumDTensor>& NRaws,
282 const std::array<index_t, NumDTensor>& DsStride)
305 AElementwiseOperation,
306 BElementwiseOperation,
307 CDEElementwiseOperation,
319 M1N1ThreadClusterM1Xs,
320 M1N1ThreadClusterN1Xs,
321 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
322 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
323 ABlockTransferThreadClusterArrangeOrder,
324 ABlockTransferSrcAccessOrder,
325 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
326 ABlockTransferSrcVectorTensorContiguousDimOrder,
327 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
328 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
329 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
330 BBlockTransferThreadClusterArrangeOrder,
331 BBlockTransferSrcAccessOrder,
332 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
333 BBlockTransferSrcVectorTensorContiguousDimOrder,
334 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
335 CThreadTransferSrcDstAccessOrder,
336 CThreadTransferSrcDstVectorDim,
337 CThreadTransferDstScalarPerVector>;
354 const void* p_b_grid,
355 std::array<const void*, NumDTensor> p_ds_grid,
362 std::array<index_t, NumDTensor> StrideDs,
364 AElementwiseOperation a_element_op,
365 BElementwiseOperation b_element_op,
366 CDEElementwiseOperation cde_element_op)
367 :
p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
368 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
370 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
388 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds_grid[i]);
447 std::cout <<
"arg.a_grid_desc_k0_m0_m1_k1_{"
452 std::cout <<
"arg.b_grid_desc_k0_n0_n1_k1_{"
464 throw std::runtime_error(
465 "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
471 auto launch_kernel = [&](
auto has_main_k_block_loop,
472 auto has_double_tail_k_block_loop) {
473 constexpr bool has_main_loop = has_main_k_block_loop.value;
474 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
481 AElementwiseOperation,
482 BElementwiseOperation,
483 CDEElementwiseOperation,
513 const bool has_double_tail_k_block_loop =
516 if(has_main_k_block_loop && has_double_tail_k_block_loop)
519 integral_constant<bool, true>{});
521 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
524 integral_constant<bool, false>{});
526 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
529 integral_constant<bool, true>{});
534 integral_constant<bool, false>{});
542 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
574 std::array<const void*, NumDTensor> p_ds,
581 std::array<ck::index_t, NumDTensor> StrideDs,
583 AElementwiseOperation a_element_op,
584 BElementwiseOperation b_element_op,
585 CDEElementwiseOperation cde_element_op)
606 std::unique_ptr<BaseArgument>
609 std::array<const void*, NumDTensor> p_ds,
616 std::array<ck::index_t, NumDTensor> StrideDs,
618 AElementwiseOperation a_element_op,
619 BElementwiseOperation b_element_op,
620 CDEElementwiseOperation cde_element_op)
override
622 return std::make_unique<Argument>(p_a,
641 return std::make_unique<Invoker>(
Invoker{});
647 auto str = std::stringstream();
650 str <<
"DeviceGemmMultipleD_Dl"
655 << K0PerBlock <<
", "
657 << M1PerThread <<
", "
658 << N1PerThread <<
", "
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition device_gemm_multiple_d_dl.hpp:39
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_multiple_d.hpp:60
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const EGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:242
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeBGridDescriptor_K0_N0_N1_K1 __host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:178
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeAGridDescriptor_K0_M0_M1_K1 __host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_multiple_d.hpp:158
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_dl_multiple_d.hpp:253
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateGridSize __host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_multiple_d.hpp:136
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasDoubleTailKBlockLoop __host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:150
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_multiple_d.hpp:143
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const EGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:110
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:234
ck::GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition gridwise_gemm_dl_multiple_d.hpp:200
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_dl.hpp:352
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_dl.hpp:418
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_dl.hpp:417
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_dl.hpp:435
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_gemm_multiple_d_dl.hpp:427
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_gemm_multiple_d_dl.hpp:426
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_dl.hpp:436
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_gemm_multiple_d_dl.hpp:422
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_gemm_multiple_d_dl.hpp:421
EDataType * p_e_grid_
Definition device_gemm_multiple_d_dl.hpp:419
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_dl.hpp:353
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_multiple_d_dl.hpp:431
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_dl.hpp:424
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_dl.hpp:416
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_gemm_multiple_d_dl.hpp:428
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_dl.hpp:423
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_dl.hpp:434
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_gemm_multiple_d_dl.hpp:429
Definition device_gemm_multiple_d_dl.hpp:441
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_dl.hpp:444
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_dl.hpp:539
DeviceGemmMultipleD_Dl::Argument Argument
Definition device_gemm_multiple_d_dl.hpp:442
Definition device_gemm_multiple_d_dl.hpp:153
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_gemm_multiple_d_dl.hpp:347
std::string GetTypeString() const override
Definition device_gemm_multiple_d_dl.hpp:645
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition device_gemm_multiple_d_dl.hpp:293
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_dl.hpp:546
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_dl.hpp:280
static constexpr auto I4
Definition device_gemm_multiple_d_dl.hpp:161
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_dl.hpp:155
static constexpr auto K1Number
Definition device_gemm_multiple_d_dl.hpp:164
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition device_gemm_multiple_d_dl.hpp:294
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_dl.hpp:552
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_dl.hpp:572
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_gemm_multiple_d_dl.hpp:299
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
Definition device_gemm_multiple_d_dl.hpp:245
static constexpr auto I3
Definition device_gemm_multiple_d_dl.hpp:160
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_gemm_multiple_d_dl.hpp:339
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_dl.hpp:639
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition device_gemm_multiple_d_dl.hpp:205
DeviceGemmMultipleD_Dl DeviceOp
Definition device_gemm_multiple_d_dl.hpp:154
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_gemm_multiple_d_dl.hpp:341
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_dl.hpp:567
decltype(MakeDsGridDescriptor_M_N({}, {}, {})) DsGridDesc_M_N
Definition device_gemm_multiple_d_dl.hpp:295
static constexpr auto I0
Definition device_gemm_multiple_d_dl.hpp:157
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition device_gemm_multiple_d_dl.hpp:166
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_dl.hpp:607
static constexpr auto I2
Definition device_gemm_multiple_d_dl.hpp:159
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition device_gemm_multiple_d_dl.hpp:343
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_dl.hpp:296
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) EGridDesc_M0_M10_M11_N0_N10_N11
Definition device_gemm_multiple_d_dl.hpp:345
static constexpr auto I1
Definition device_gemm_multiple_d_dl.hpp:158
static auto MakeInvoker()
Definition device_gemm_multiple_d_dl.hpp:603
static constexpr auto I5
Definition device_gemm_multiple_d_dl.hpp:162
Definition device_gemm_multiple_d.hpp:36