BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy > Struct Template Reference
#include <block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp>
Public Types | |
| using | QDataType = remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = remove_cvref_t<typename Problem::VDataType> |
| using | GemmDataType = remove_cvref_t<typename Problem::GemmDataType> |
| using | BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | AccDataType = remove_cvref_t<typename Problem::AccDataType> |
| using | DDataType = remove_cvref_t<typename Problem::DDataType> |
| using | RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
| using | OGradDataType = remove_cvref_t<typename Problem::OGradDataType> |
| using | QGradDataType = remove_cvref_t<typename Problem::QGradDataType> |
| using | KGradDataType = remove_cvref_t<typename Problem::KGradDataType> |
| using | VGradDataType = remove_cvref_t<typename Problem::VGradDataType> |
| using | BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType> |
| using | FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
| using | FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout> |
| using | BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
Public Member Functions | |
| template<typename... Ts> | |
| CK_TILE_DEVICE auto | operator() (void *smem_ptr, Ts &&... args) const |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename OGradDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename DDramBlockWindowTmp, typename QGradDramBlockWindowTmp, typename BiasGradDramBlockWindowTmp, typename PositionEncoding> | |
| CK_TILE_DEVICE auto | run (KDataType *__restrict__ k_lds_ptr, VDataType *__restrict__ v_lds_ptr, OGradDataType *__restrict__ do_lds_ptr0, OGradDataType *__restrict__ do_lds_ptr1, QDataType *__restrict__ q_lds_ptr0, QDataType *__restrict__ q_lds_ptr1, LSEDataType *__restrict__ lse_lds_ptr0, LSEDataType *__restrict__ lse_lds_ptr1, DDataType *__restrict__ d_lds_ptr0, DDataType *__restrict__ d_lds_ptr1, GemmDataType *__restrict__ ds_lds_ptr, BiasDataType *__restrict__ bias_lds_ptr, const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const RandValDramBlockWindowTmp &randval_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, const LSEDramBlockWindowTmp &lse_dram_block_window_tmp, const DDramBlockWindowTmp &d_dram_block_window_tmp, const QGradDramBlockWindowTmp &dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp &dbias_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float raw_scale, float scale, float rp_undrop, float scale_rp_undrop, FmhaDropout &dropout) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
| static CK_TILE_HOST_DEVICE LSEDataType | get_validated_lse (const LSEDataType raw_lse) |
Static Public Attributes | |
| static constexpr index_t | kBlockPerCu = Problem::kBlockPerCu |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr index_t | kK2 = BlockFmhaShape::kK2 |
| static constexpr index_t | kK3 = BlockFmhaShape::kK3 |
| static constexpr index_t | kK4 = BlockFmhaShape::kK4 |
| static constexpr index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr index_t | kVHeaddim = BlockFmhaShape::kVHeaddim |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr index_t | kPadHeadDimQ = Problem::kPadHeadDimQ |
| static constexpr index_t | kPadHeadDimV = Problem::kPadHeadDimV |
| static constexpr auto | BiasEnum = Problem::BiasEnum |
| static constexpr bool | kHasBiasGrad = Problem::kHasBiasGrad |
| static constexpr bool | kIsDeterministic = Problem::kIsDeterministic |
| static constexpr bool | kUseTrLoad = Problem::kUseTrLoad |
| static constexpr index_t | kAlignmentQ |
| static constexpr index_t | kAlignmentK |
| static constexpr index_t | kAlignmentV |
| static constexpr index_t | kAlignmentOGrad |
| static constexpr index_t | kAlignmentQGrad = 1 |
| static constexpr index_t | kAlignmentKGrad |
| static constexpr index_t | kAlignmentVGrad |
| static constexpr index_t | kAlignmentBias = 1 |
| static constexpr const char * | name = "trload_kr_ktr_vr" |
Member Typedef Documentation
◆ AccDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::AccDataType = remove_cvref_t<typename Problem::AccDataType> |
◆ BiasDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
◆ BiasGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType> |
◆ BlockFmhaShape
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ DDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::DDataType = remove_cvref_t<typename Problem::DDataType> |
◆ FmhaDropout
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout> |
◆ FmhaMask
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
◆ GemmDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::GemmDataType = remove_cvref_t<typename Problem::GemmDataType> |
◆ KDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::KDataType = remove_cvref_t<typename Problem::KDataType> |
◆ KGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::KGradDataType = remove_cvref_t<typename Problem::KGradDataType> |
◆ LSEDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ ODataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ OGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::OGradDataType = remove_cvref_t<typename Problem::OGradDataType> |
◆ QDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::QDataType = remove_cvref_t<typename Problem::QDataType> |
◆ QGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::QGradDataType = remove_cvref_t<typename Problem::QGradDataType> |
◆ RandValOutputDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
◆ VDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::VDataType = remove_cvref_t<typename Problem::VDataType> |
◆ VGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR< Problem, Policy >::VGradDataType = remove_cvref_t<typename Problem::VGradDataType> |
Member Function Documentation
◆ get_validated_lse()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
inline static |
◆ GetSmemSize()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
inline static constexpr |
◆ operator()()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
template<typename... Ts>
|
inline |
◆ run()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename OGradDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename DDramBlockWindowTmp, typename QGradDramBlockWindowTmp, typename BiasGradDramBlockWindowTmp, typename PositionEncoding>
|
inline |
Member Data Documentation
◆ BiasEnum
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kAlignmentBias
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kAlignmentK
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>()
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52
◆ kAlignmentKGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>()
◆ kAlignmentOGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_dot_do_o.hpp:24
◆ kAlignmentQ
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>()
◆ kAlignmentQGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kAlignmentV
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>()
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53
◆ kAlignmentVGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>()
◆ kBlockPerCu
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kBlockSize
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kHasBiasGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kIsDeterministic
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kIsGroupMode
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK0
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK1
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK2
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK3
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK4
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kM0
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kN0
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimQ
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimV
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kQKHeaddim
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kUseTrLoad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kVHeaddim
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ name
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
The documentation for this struct was generated from the following file: