wp_pipeline_agmem_bgmem_creg_base_policy.hpp Source File

wp_pipeline_agmem_bgmem_creg_base_policy.hpp Source File#

Composable Kernel: wp_pipeline_agmem_bgmem_creg_base_policy.hpp Source File
wp_pipeline_agmem_bgmem_creg_base_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
12 : public UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>
13{
15
16 // 3d + padding
17 template <typename Problem>
19 {
20 using namespace ck_tile;
21 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
22 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
23 constexpr index_t kKPack = GetSmemPackA<Problem>();
25
26 constexpr auto DataTypeSize = sizeof(ADataType);
27 constexpr auto MLdsLayer =
28 (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
29
30 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
32 number<kMPerBlock / MLdsLayer>{},
36 number<1>{});
37
38 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
39 a_lds_block_desc_0,
41 number<kKPerBlock / kKPack * MLdsLayer>{})),
45
46 constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
47 a_lds_block_desc_permuted,
49 make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
54
55 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
56 a_lds_block_desc_xk0_mnldslayer_mn_xk1,
63 return a_lds_block_desc;
64 }
65
66 template <typename Problem>
68 {
69 constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
70 MakeALdsBlockDescriptor<Problem>().get_element_space_size();
71 return smem_size_a;
72 }
73
74 template <typename Problem>
76 {
77 constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
78
79 return smem_size_a;
80 }
81
82 template <typename Problem>
83 CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
84 {
85 return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
86 }
87
88 template <typename Problem>
89 CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
90 {
91 using TileShape = typename Problem::BlockGemmShape;
92#if defined(__gfx11__)
93 constexpr index_t scale = 4;
94#else
95 constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
96#endif
97 if constexpr(TileShape::WarpTile::at(I1) == 32)
98 {
99 return TileShape::WarpTile::at(I2) * scale / 2;
100 }
101 else
102 {
103 static_assert(TileShape::WarpTile::at(I1) == 16);
104 return TileShape::WarpTile::at(I2) * scale / 4;
105 }
106 }
107
108 template <typename Problem>
110 {
113
114 constexpr index_t BlockSize = Problem::kBlockSize;
115
116 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
117 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
118
119 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
120 {
121 constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
122 constexpr index_t M0 = MPerBlock / M1;
123 constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
124 static_assert(total_pixels % M1 == 0);
125 constexpr index_t K3 = total_pixels / M1;
126 constexpr index_t KPack = GetSmemPackA<Problem>();
127 static_assert(KPack % K3 == 0);
128 constexpr index_t K2 = KPack / K3;
129 if constexpr(get_warp_size() >= (K2 * M0))
130 {
131 constexpr index_t K1 = get_warp_size() / (K2 * M0);
132 constexpr index_t K0 = BlockSize / get_warp_size();
133 static_assert(KPerBlock == K0 * K1 * K2 * K3);
140 sequence<3, 1>>{});
141 }
142 else
143 {
144 constexpr index_t K1 = (K2 * M0) / get_warp_size();
145 constexpr index_t K2_m = K2 / K1;
146 constexpr index_t K0 = BlockSize / get_warp_size() / K1;
147 static_assert(KPerBlock == K0 * K1 * K2_m * K3);
154 sequence<3, 1>>{});
155 }
156 }
157 else
158 {
159 constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
160 constexpr index_t K0 = KPerBlock / K1;
161 constexpr index_t M2 = get_warp_size() / K0;
162 // coalesce reading for each blocks
163 if constexpr(get_warp_size() % (M2 * K0) == 0)
164 {
165 constexpr index_t M1 = BlockSize / get_warp_size();
166 static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
167 static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
168 constexpr index_t M0 = MPerBlock / (M2 * M1);
169 static_assert(M0 * M1 * M2 == MPerBlock,
170 "Incorrect M0, M2, M1 configuration! "
171 "M0, M1, M2 must cover whole MPerBlock!");
172
179 sequence<0, 1>>{});
180 }
181 else
182 {
183 constexpr index_t M0 = BlockSize / get_warp_size();
184 constexpr index_t M1 = MPerBlock / (M2 * M0);
185 static_assert(M0 * M1 * M2 == MPerBlock,
186 "Incorrect M0, M1, M2 configuration! "
187 "M0, M1, M2 must cover whole MPerBlock!");
194 sequence<1, 1>>{});
195 }
196 }
197 }
198
199 template <typename Problem>
201 {
202 using TileShape = typename Problem::BlockGemmShape;
203
204 constexpr index_t BlockSize = Problem::kBlockSize;
205 constexpr index_t WaveSize = get_warp_size();
206 constexpr index_t WaveNum = BlockSize / WaveSize;
207
208 constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
209#if defined(__gfx11__)
210 constexpr index_t KRepeatInWave = 2;
211#else
212 constexpr index_t KRepeatInWave = 1;
213#endif
214 constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
215 constexpr index_t KWavePerBlk = 1;
216 constexpr index_t KRepeat = 1;
217 static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
218
219 constexpr index_t NBPerLoad = 1;
220 constexpr index_t NThdPerWave = 1;
221 constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
222 constexpr index_t NRepeat = 1;
223
224 constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
230 // wave in blk, // thd in wave
231 // <M, K> // <M, K>
232 tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
234 // <repeat, vec_load>
237 }
238
239 template <typename Problem>
241 {
244 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
245 constexpr index_t kBlockSize = Problem::kBlockSize;
246 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
247 constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
248
249 constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
250 constexpr index_t M0 = kMPerBlock / M1;
251 constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
252 static_assert(total_pixels % M1 == 0);
253 constexpr index_t K3 = total_pixels / M1;
254 constexpr index_t kKPack = GetSmemPackA<Problem>();
255 static_assert(kKPack % K3 == 0);
256 constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
257 constexpr index_t warp_size = get_warp_size();
258 if constexpr(warp_size >= (K2 * M0))
259 {
260 constexpr index_t K1 = warp_size / (K2 * M0);
261 constexpr index_t K0 = kBlockSize / warp_size;
262
269 sequence<1, 3>>{});
270 }
271 else
272 {
273 constexpr index_t K1 = (K2 * M0) / get_warp_size();
274 constexpr index_t K2_m = K2 / K1;
275 constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
276 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
283 sequence<1, 3>>{});
284 }
285 }
286
287 template <typename Problem>
289 {
290 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
291 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
292 using BTypeToUse =
293 std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
294 typename Problem::ADataType,
295 typename Problem::BDataType>;
296 using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
297 BTypeToUse,
298 typename Problem::CDataType,
299 WarpTile::at(I0),
300 WarpTile::at(I1),
301 WarpTile::at(I2),
302 Problem::TransposeC>;
303
304 using BlockWeightPreshufflePolicy =
305 BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
306 typename Problem::BDataType,
307 typename Problem::CDataType,
308 BlockWarps,
309 WarpGemm>;
311 }
312
324 template <typename Problem>
325 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
326 {
328 using WG_ = typename BlockGemm::WG;
329
330 constexpr bool TransposeC = Problem::TransposeC;
331 using CLayout = typename Problem::CLayout;
332 using CWarpDstr = typename WG_::CWarpDstr;
333
334 // N is contiguous dimension
335 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
336 {
337 if constexpr(TransposeC)
338 {
339 // In this case each thread has multiple consecutive elements in
340 // N dimension, however consecutive threads' elements have stride.
341 constexpr index_t NDimY = CWarpDstr::NDimY;
342 constexpr auto c_warp_y_lengths =
343 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
344 static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
345 c_warp_y_lengths.get(number<NDimY - 1>{}));
346 return c_warp_y_lengths.get(number<NDimY - 1>{});
347 }
348 else
349 {
350 // In this case each thread has just a single item in Ndim
351 return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
352 }
353 }
354 // M is contiguous dimension
355 else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
356 {
357 if constexpr(TransposeC)
358 {
359 // In this case each thread has just a single item in Mdim
360 return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
361 }
362 else
363 {
364 // In this case each thread has multiple consecutive elements in
365 // M dimension, however consecutive threads' elements have stride.
366 constexpr index_t NDimY = CWarpDstr::NDimY;
367 constexpr auto c_warp_y_lengths =
368 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
369 static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
370 c_warp_y_lengths.get(number<NDimY - 1>{}));
371 return c_warp_y_lengths.get(number<NDimY - 1>{});
372 }
373 }
374 else
375 {
376 static_assert(false, "Unsupported CLayout!");
377 }
378 }
379};
380
381} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_wp_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_wp_asmem_bsmem_creg_v1.hpp:16
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:34
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:109
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:83
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegBlockDistribution()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:240
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeA()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:67
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:75
static CK_TILE_HOST_DEVICE constexpr auto GetKBPerLoad()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:89
UniversalGemmBasePolicy< UniversalWeightPreshufflePipelineAgBgCrPolicy > BasePolicy
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeC()
Get the vector store size for C tensor.
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:325
static CK_TILE_HOST_DEVICE constexpr auto GetBlockWeightPreshuffle()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:288
static CK_TILE_HOST_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:18
static CK_TILE_DEVICE constexpr auto MakeBFlatDramTileDistribution()
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:200
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192