device_gemm_xdl.hpp Source File
device_gemm_xdl.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:34
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:240
Definition gridwise_gemm_xdlops_v2r3.hpp:814
static __host__ constexpr bool CheckValidity(const Problem &problem)
Definition gridwise_gemm_xdlops_v2r3.hpp:1003
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
BaseInvoker()=default
Definition device_gemm.hpp:22
Definition device_gemm_xdl.hpp:134
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl.hpp:175
float RunImp(const typename GridwiseGemm::Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl.hpp:136
Definition device_gemm_xdl.hpp:71
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl.hpp:188
static constexpr auto K1Number
Definition device_gemm_xdl.hpp:80
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl.hpp:128
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl.hpp:73
static constexpr auto I0
Definition device_gemm_xdl.hpp:76
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl.hpp:234
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl.hpp:130
static auto MakeInvoker()
Definition device_gemm_xdl.hpp:255
std::string GetTypeString() const override
Definition device_gemm_xdl.hpp:289
static constexpr auto I2
Definition device_gemm_xdl.hpp:78
static constexpr auto I1
Definition device_gemm_xdl.hpp:77
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl.hpp:239
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 0, 2, 4, 5, 6, 1, 3, 7 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumPrefetch, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_gemm_xdl.hpp:84
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl.hpp:182
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl.hpp:283
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl.hpp:74
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl.hpp:127
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl.hpp:258