device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp File Reference

device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp File Reference

Composable Kernel: device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp File Reference
device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp File Reference

Go to the source code of this file.

Classes

struct  ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, Acc0DataType, D0sDataType, B1DataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, PadGemm0M, PadGemm0N, PadGemm0K, PadGemm1N, PadGemm1K, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >
struct  ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, Acc0DataType, D0sDataType, B1DataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, PadGemm0M, PadGemm0N, PadGemm0K, PadGemm1N, PadGemm1K, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::ComputeBasePtrOfStridedBatch
struct  ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, Acc0DataType, D0sDataType, B1DataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, PadGemm0M, PadGemm0N, PadGemm0K, PadGemm1N, PadGemm1K, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::Argument
struct  ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, Acc0DataType, D0sDataType, B1DataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, PadGemm0M, PadGemm0N, PadGemm0K, PadGemm1N, PadGemm1K, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >::Invoker

Namespaces

namespace  ck
namespace  ck::tensor_operation
namespace  ck::tensor_operation::device

Functions

template<typename GridwiseGemm, typename A0B0B1DataType, typename D0sPointer, typename D1sPointer, typename E1DataType, typename A0ElementwiseOperation, typename B0ElementwiseOperation, typename CDE0ElementwiseOperation, typename B1ElementwiseOperation, typename CDE1ElementwiseOperation, typename A0GridDesc_AK0_M_AK1, typename B0GridDesc_BK0_N_BK1, typename D0sGridDesc_M_N, typename B1GridDesc_BK0_N_BK1, typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename Block2E1TileMap, typename ComputeBasePtrOfStridedBatch, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_gemm_xdl_cshuffle_v1 (const A0B0B1DataType *__restrict__ p_a0_grid, const A0B0B1DataType *__restrict__ p_b0_grid, D0sPointer p_d0s_grid, const A0B0B1DataType *__restrict__ p_b1_grid, D1sPointer p_d1s_grid, E1DataType *__restrict__ p_e1_grid, const A0ElementwiseOperation a0_element_op, const B0ElementwiseOperation b0_element_op, const CDE0ElementwiseOperation cde0_element_op, const B1ElementwiseOperation b1_element_op, const CDE1ElementwiseOperation cde1_element_op, const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, const D0sGridDesc_M_N d0s_griddesc_m_n, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock d1s_grid_desc_mblock_mperblock_nblock_nperblock, const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e1_grid_desc_mblock_mperblock_nblock_nperblock, const Block2E1TileMap block_2_e1tile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)