QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ > Struct Template Reference

QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ > Struct Template Reference

Composable Kernel: ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ > Struct Template Reference
ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ > Struct Template Reference

#include <grouped_gemm_quant_kernel.hpp>

Public Types

using Base = QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>
 Inject the QuantGemmKernel base class (aliased as Base) to support execution of all necessary functions.
using TilePartitioner = remove_cvref_t<TilePartitioner_>
using GemmPipeline = remove_cvref_t<GemmPipeline_>
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>
 Specify the data type configurations for A, B, C/E.
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>
using AQDataType
using BQDataType
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>
 ALayout and ADataType are expected to be scalars, not a tuple.
using Kernel

Public Member Functions

CK_TILE_DEVICE void Run (const QuantGroupedGemmKernelArgs &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
template<bool U = UsePersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void operator() (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count) const

Static Public Member Functions

static CK_TILE_HOST const std::string GetName ()
static CK_TILE_HOST auto GetWorkSpaceSize (const std::vector< QuantGroupedGemmHostArgs > &gemm_descs) -> std::size_t
static CK_TILE_HOST auto GetWorkSpaceSize (index_t group_count) -> std::size_t
static CK_TILE_HOST auto BlockSize () -> dim3
static CK_TILE_HOST auto MaxOccupancyGridSize (const stream_config &s) -> dim3
 Get the maximum occupancy grid size for the persistent kernel on the current device.
static CK_TILE_HOST auto GridSize (const std::vector< QuantGroupedGemmHostArgs > &gemm_descs)
static CK_TILE_HOST auto MakeKargs (const std::vector< QuantGroupedGemmHostArgs > &gemm_descs) -> std::vector< QuantGemmTransKernelArg >
static CK_TILE_HOST bool IsSupportedArgument (const std::vector< QuantGemmTransKernelArg > &kargs)
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize () -> index_t
template<memory_operation_enum DstInMemOp = memory_operation_enum::set>
static CK_TILE_DEVICE void RunGemmWithPipelineSelection2LDS (const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, void *smem_ptr_1, const QuantGroupedGemmKernelArgs &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
static CK_TILE_DEVICE void RunGemmWithPipelineSelection (const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, const QuantGroupedGemmKernelArgs &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs single GEMM problem cooperatively by whole workgroup.

Static Public Attributes

static constexpr auto kQuantType = QuantType_
static constexpr index_t kBlockSize = GemmPipeline::BlockSize
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel

Member Typedef Documentation

◆ AccDataType

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>

◆ ADataType

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::ADataType = remove_cvref_t<typename GemmPipeline::ADataType>

Specify the data type configurations for A, B, C/E.

◆ ALayout

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::ALayout = remove_cvref_t<typename GemmPipeline::ALayout>

◆ AQDataType

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::AQDataType
Initial value:
(full initializer not captured in this extract; it references remove_cvref_t, which is defined in type_traits.hpp:21 as remove_cv_t< std::remove_reference_t< T > >)

◆ Base

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::Base = QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>

Inject the QuantGemmKernel base class (aliased as Base) to support execution of all necessary functions.

◆ BDataType

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::BDataType = remove_cvref_t<typename GemmPipeline::BDataType>

◆ BLayout

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::BLayout = remove_cvref_t<typename GemmPipeline::BLayout>

◆ BQDataType

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::BQDataType

◆ CDataType

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>

◆ CLayout

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::CLayout = remove_cvref_t<typename GemmPipeline::CLayout>

◆ EpiloguePipeline

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>

◆ GemmPipeline

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::GemmPipeline = remove_cvref_t<GemmPipeline_>

◆ Kernel

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::Kernel

◆ OffsetTile1DPartitioner

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>

ALayout and ADataType are expected to be scalars, not a tuple.

BLayout and BDataType are expected to be scalars, not a tuple.

C/ELayout and C/EDataType are expected to be scalars, not a tuple.

◆ TilePartitioner

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
using ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::TilePartitioner = remove_cvref_t<TilePartitioner_>

Member Function Documentation

◆ BlockSize()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::BlockSize ( ) ->dim3
inline static

◆ GetName()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST const std::string ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::GetName ( )
inline static nodiscard

◆ GetSmemSize()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST_DEVICE constexpr auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::GetSmemSize ( ) ->index_t
inline static constexpr

◆ GetWorkSpaceSize() [1/2]

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::GetWorkSpaceSize ( const std::vector< QuantGroupedGemmHostArgs > & gemm_descs) ->std::size_t
inline static

◆ GetWorkSpaceSize() [2/2]

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::GetWorkSpaceSize ( index_t group_count) ->std::size_t
inline static

◆ GridSize()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::GridSize ( const std::vector< QuantGroupedGemmHostArgs > & gemm_descs)
inline static

◆ IsSupportedArgument()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST bool ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::IsSupportedArgument ( const std::vector< QuantGemmTransKernelArg > & kargs)
inline static

◆ MakeKargs()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::MakeKargs ( const std::vector< QuantGroupedGemmHostArgs > & gemm_descs) ->std::vector< QuantGemmTransKernelArg >
inline static

◆ MaxOccupancyGridSize()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_HOST auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::MaxOccupancyGridSize ( const stream_config & s) ->dim3
inline static

Get the maximum occupancy grid size for the persistent kernel on the current device.

Returns
The maximum occupancy grid size.
Note
This function queries the maximum occupancy of the kernel using hipOccupancyMaxActiveBlocksPerMultiprocessor.

◆ operator()()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
template<bool U = UsePersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::operator() ( const void CK_CONSTANT_ADDRESS_SPACE * gemm_descs_const,
const index_t group_count ) const
inline

◆ Run()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_DEVICE void ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::Run ( const QuantGroupedGemmKernelArgs & kargs,
const tuple< index_t, index_t > & block_idx_2d,
const index_t block_idx_z ) const
inline

◆ RunGemmWithPipelineSelection()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
CK_TILE_DEVICE void ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::RunGemmWithPipelineSelection ( const ADataType * a_ptr,
const BDataType * b_ptr,
const AQDataType * aq_ptr,
const BQDataType * bq_ptr,
CDataType * c_ptr,
void * smem_ptr_0,
const QuantGroupedGemmKernelArgs & kargs,
const typename Base::SplitKBatchOffset & splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n )
inline static

Runs single GEMM problem cooperatively by whole workgroup.

Note
The GEMM pipeline is selected in-kernel based on the number of K-loops and the tail number. This is needed for the persistent tile loop, where the K dimension is not available on the host.
Parameters
a_ptr: input A pointer
b_ptr: input B pointer
aq_ptr: input AQ pointer
bq_ptr: input BQ pointer
c_ptr: output C pointer
smem_ptr_0: The start memory pointer of the shared memory block.
kargs: GEMM kernel arguments
splitk_batch_offset: Utility structure used to calculate the split-K batch offset.
block_idx_m: The GEMM's output M dimension tile index processed by this workgroup.
block_idx_n: The GEMM's output N dimension tile index processed by this workgroup.

◆ RunGemmWithPipelineSelection2LDS()

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
template<memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE void ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::RunGemmWithPipelineSelection2LDS ( const ADataType * a_ptr,
const BDataType * b_ptr,
const AQDataType * aq_ptr,
const BQDataType * bq_ptr,
CDataType * c_ptr,
void * smem_ptr_0,
void * smem_ptr_1,
const QuantGroupedGemmKernelArgs & kargs,
const typename Base::SplitKBatchOffset & splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n )
inline static

Member Data Documentation

◆ kBlockSize

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
index_t ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::kBlockSize = GemmPipeline::BlockSize
static constexpr

◆ kQuantType

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
auto ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::kQuantType = QuantType_
static constexpr

◆ UsePersistentKernel

template<typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_, QuantType QuantType_>
bool ck_tile::QuantGroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ >::UsePersistentKernel = GemmPipeline::UsePersistentKernel
static constexpr

The documentation for this struct was generated from the following file: