ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference

ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference

Composable Kernel: ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference
ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 > Struct Template Reference

#include <variants.hpp>

Public Member Functions

__device__ __host__ ComposedAttention ()=default
template<typename Params, typename T>
__device__ __forceinline__ T QueryTransform (const Params &params, T q) const
template<typename Params, typename T>
__device__ __forceinline__ T LogitsTransform (const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
template<typename Params>
__device__ __forceinline__ bool LogitsMask (const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const

Static Public Attributes

static constexpr bool use_exp2 = UseExp2
static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0

Constructor & Destructor Documentation

◆ ComposedAttention()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
__device__ __host__ ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::ComposedAttention ( )
default

Member Function Documentation

◆ LogitsMask()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params>
__device__ __forceinline__ bool ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::LogitsMask ( const Params & params,
uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
uint32_t qo_head_idx,
uint32_t kv_head_idx ) const
inline

◆ LogitsTransform()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params, typename T>
__device__ __forceinline__ T ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::LogitsTransform ( const Params & params,
T logits,
uint32_t batch_idx,
uint32_t qo_head_idx,
uint32_t kv_head_idx ) const
inline

NOTICE: For better performance, this function simply transforms the thread buffer without calculating qo_idx/kv_idx.

◆ QueryTransform()

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
template<typename Params, typename T>
__device__ __forceinline__ T ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::QueryTransform ( const Params & params,
T q ) const
inline

Member Data Documentation

◆ use_exp2

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
bool ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::use_exp2 = UseExp2
static constexpr

◆ use_logits_soft_cap

template<uint32_t VARIANT_CODE, bool UseExp2 = false>
bool ck_tile::ComposedAttention< VARIANT_CODE, UseExp2 >::use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0
static constexpr

The documentation for this struct was generated from the following file: