block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp Source File
block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
Go to the documentation of this file.
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
ELEMENTWISE_BIAS (enumerator)
Definition block_attention_bias_enum.hpp:14
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:16
remove_cvref_t< typename Problem::DDataType > DDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:24
static constexpr bool kHasBiasGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:55
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:23
static constexpr index_t kAlignmentQGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:70
static constexpr index_t kAlignmentOGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:68
static constexpr index_t kK1
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:44
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:19
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:32
CK_TILE_HOST_DEVICE auto operator()(void *smem_ptr, const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const RandValDramBlockWindowTmp &randval_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, const LSEDramBlockWindowTmp &lse_dram_block_window_tmp, const DDramBlockWindowTmp &d_dram_block_window_tmp, const QGradDramBlockWindowTmp &dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp &dbias_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float raw_scale, float scale, float rp_undrop, float scale_rp_undrop, FmhaDropout &dropout) const
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:96
static constexpr index_t kVHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:49
remove_cvref_t< typename Problem::VGradDataType > VGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:30
remove_cvref_t< typename Problem::BiasGradDataType > BiasGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:31
static constexpr index_t kAlignmentVGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:73
static constexpr index_t kAlignmentK
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:64
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52
static constexpr index_t kQKHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:48
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:25
remove_cvref_t< typename Problem::KGradDataType > KGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:29
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:56
static constexpr bool kUseTrLoad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:57
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:38
remove_cvref_t< typename Problem::OGradDataType > OGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:27
static constexpr index_t kM0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:41
remove_cvref_t< typename Problem::QGradDataType > QGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:28
static constexpr index_t kK2
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:45
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:17
static constexpr index_t kK4
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:47
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:36
remove_cvref_t< typename Problem::FmhaDropout > FmhaDropout
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:33
static constexpr index_t kAlignmentQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:62
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:26
typename Policy::template HotLoopScheduler< Problem > HotLoopScheduler
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:34
remove_cvref_t< typename Problem::GemmDataType > GemmDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:20
static constexpr index_t kK3
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:46
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:18
static constexpr index_t kAlignmentKGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:71
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:51
static constexpr index_t kN0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:42
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:79
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:22
static constexpr auto BiasEnum
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:54
static constexpr index_t kBlockSize
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:39
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:21
static constexpr index_t kAlignmentBias
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:75
static constexpr const char * name
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:77
static constexpr index_t kAlignmentV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:66
static constexpr index_t kK0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:43
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43