block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp Source File
block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
Go to the documentation of this file.
8 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
14 template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:16
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:26
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:40
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:38
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &, const AttentionVariantParams &, const BlockIndices &, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:136
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:34
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:50
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:43
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:19
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:54
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:113
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:73
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:31
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:53
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:62
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:52
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:41
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:75
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:35
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:29
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:42
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:33
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:44
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:27
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:21
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:24
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:66
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:45
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:57
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:56
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:22
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:30
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:904
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:25
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:18
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:58
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:55
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:109
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:17
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:20
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:64
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:51
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:46
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:28
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:78
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:23
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:111
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43