gemm_aquant_pipeline_ag_bg_cr_mem.hpp Source File

gemm_aquant_pipeline_ag_bg_cr_mem.hpp Source File

Composable Kernel: gemm_aquant_pipeline_ag_bg_cr_mem.hpp Source File
gemm_aquant_pipeline_ag_bg_cr_mem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <string>
7#include <sstream>
8
9#include "ck_tile/core.hpp"
15
16namespace ck_tile {
17
18template <typename Problem>
20{
22 {
24 {
25 return TailNumber::Even;
26 }
27 else
28 {
29 return TailNumber::Odd;
30 }
31 }
// Dispatches to run_func along one of four paths chosen by (has_hot_loop, tail_number).
// Only TailNumber::Odd and TailNumber::Even are handled; any other tail_number throws
// std::runtime_error, in both the hot-loop and the no-hot-loop case.
// NOTE(review): the argument lists of the run_func calls (original lines 41-42, 47-48,
// 61-62, 67-68) are missing from this listing, so what each path passes cannot be confirmed here.
// NOTE(review): the function is marked CK_TILE_HOST_DEVICE yet contains `throw`; confirm this
// is acceptable when the function is compiled for the device.
32 template <typename RunFunction>
33 CK_TILE_HOST_DEVICE static auto
34 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
35 {
36 if(has_hot_loop)
37 {
38 if(tail_number == ck_tile::TailNumber::Odd)
39 {
40 return run_func(
43 }
44 else if(tail_number == ck_tile::TailNumber::Even)
45 {
46 return run_func(
49 }
50 else
51 {
52 throw std::runtime_error("Unsupported tail number for this operation !!!");
53 }
54 }
55 else
56 {
57
58 if(tail_number == ck_tile::TailNumber::Odd)
59 {
60 return run_func(
63 }
64 else if(tail_number == ck_tile::TailNumber::Even)
65 {
66 return run_func(
69 }
70 else
71 {
72 throw std::runtime_error("Unsupported tail number for this operation !!!");
73 }
74 }
75 }
76};
77
78template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
80{
83
90
// Compile-time restriction on the quantization group shape: both QuantGroupSize::kM and
// QuantGroupSize::kN must equal 1, otherwise compilation fails with the messages below.
91 static_assert(QuantGroupSize::kM == 1, "no block for M supported yet!");
92 static_assert(QuantGroupSize::kN == 1, "only M/K blocks for AQuant kernel!");
93
94 using I0 = number<0>;
95 using I1 = number<1>;
96 using I2 = number<2>;
97
98 static constexpr index_t APackedSize =
100 static constexpr index_t BPackedSize =
102
103 static constexpr index_t AQPackedSize =
105
110
112
// Block size taken from Problem; M/N/K block-tile extents taken from BlockGemmShape.
113 static constexpr index_t BlockSize = Problem::kBlockSize;
114 static constexpr index_t MPerBlock = BlockGemmShape::kM;
115 static constexpr index_t NPerBlock = BlockGemmShape::kN;
116 static constexpr index_t KPerBlock = BlockGemmShape::kK;
// K block-tile extent divided (integer division) by the quantization group size along K.
117 static constexpr index_t KPerBlockAQ = BlockGemmShape::kK / QuantGroupSize::kK;
118
// Returns the A vector size as reported by Policy::GetVectorSizeA<Problem>().
119 static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
// Returns the B vector size as reported by Policy::GetVectorSizeB<Problem>().
120 static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
// Returns the C vector size as reported by Policy::GetVectorSizeC<Problem>().
121 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
// Returns the AQ vector size as reported by Policy::GetVectorSizeAQ<Problem>().
122 static constexpr index_t GetVectorSizeAQ()
123 {
124 return Policy::template GetVectorSizeAQ<Problem>();
125 }
126
// Returns Policy::GetSmemPackA<Problem>(); Print() uses it as the A LDS read and write width.
127 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
// Returns Policy::GetSmemPackB<Problem>(); Print() uses it as the B LDS read and write width.
128 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
129
// Padding flags for M, N and K, forwarded unchanged from Problem.
130 static constexpr bool kPadM = Problem::kPadM;
131 static constexpr bool kPadN = Problem::kPadN;
132 static constexpr bool kPadK = Problem::kPadK;
133
// Buffering option from Problem and the quant pre-shuffle option from Problem::Traits.
134 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
135 static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
136
// Loop shape (hot loop present, tail kind) and scheduler selection, all taken from Problem.
137 static constexpr bool HasHotLoop = Problem::HasHotLoop;
138 static constexpr auto TailNum = Problem::TailNum;
139 static constexpr auto Scheduler = Problem::Scheduler;
140
142
// Builds a '_'-separated identifier for this pipeline configuration. Visible components, in
// order: the literal "aquant_pipeline_AgBgCrMem", BlockSize, the wave layout (WaveNumM x WaveNumN),
// the warp-gemm shape (kM x kN x kK), the pad flags (M x N x K), the quant group name, and
// "interwave" or "intrawave" depending on Scheduler.
// NOTE(review): original line 149 (between the literal and BlockSize) is missing from this
// listing, so at least one concat argument is not shown.
// NOTE(review): the return type is `const std::string` by value; the const qualifier prevents
// callers from moving out of the result — confirm it is intentional.
143 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
144 {
145 // clang-format off
146 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
147 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
148 return concat('_', "aquant_pipeline_AgBgCrMem",
150 BlockSize,
151 concat('x', WaveNumM, WaveNumN),
152 concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
153 concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName(),
154 Scheduler == GemmPipelineScheduler::Interwave ? "interwave" : "intrawave"); // else Intrawave
155 // clang-format on
156 }
157
159 {
160 return Policy::template GetSmemSize<Problem>();
161 }
162
// Returns a multi-line, human-readable summary of this pipeline configuration: vector sizes,
// LDS access widths, estimated per-block instruction counts (buffer loads, LDS writes/reads,
// MFMA), the quant group name, BlockGemm::Traits::KPack and PrefetchStages.
// All figures are compile-time constants derived from the block/warp tile sizes.
// NOTE(review): the initializers of A/B/AQ_Buffer_Load_Inst_Num (original lines 180, 182, 184)
// are missing from this listing.
163 CK_TILE_HOST static std::string Print()
164 {
// Warp-level gemm tile extents; note KPerXDL comes from WarpGemmAttribute::Impl::kK.
165 constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
166 constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
167 constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
168
// NOTE(review): wave size is hard-coded to 64 — confirm this holds for every supported target.
169 constexpr index_t WaveSize = 64;
170 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
171 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
172
// Read and write widths are both taken from the same GetSmemPackA/B values.
173 constexpr index_t A_LDS_Read_Width = GetSmemPackA();
174 constexpr index_t B_LDS_Read_Width = GetSmemPackB();
175
176 constexpr index_t A_LDS_Write_Width = GetSmemPackA();
177 constexpr index_t B_LDS_Write_Width = GetSmemPackB();
178
179 constexpr index_t A_Buffer_Load_Inst_Num =
181 constexpr index_t B_Buffer_Load_Inst_Num =
183 constexpr index_t AQ_Buffer_Load_Inst_Num =
185
// LDS write count: tile elements divided by (threads per block * write width).
186 constexpr index_t A_LDS_Write_Inst_Num =
187 MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
188 constexpr index_t B_LDS_Write_Inst_Num =
189 NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
190
// LDS read count: same form, additionally multiplied by the wave count of the other dimension
// (WaveNumN for A, WaveNumM for B).
191 constexpr index_t A_LDS_Read_Inst_Num =
192 WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
193 constexpr index_t B_LDS_Read_Inst_Num =
194 WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
195
// MFMA count: block tile volume divided by the number of waves (BlockSize / WaveSize) and by
// the warp-gemm tile volume.
196 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
197 (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
198
199 auto str = std::stringstream{};
200
// NOTE(review): the "read/write width" line prints only the read widths; the write widths are
// equal by construction above but are not themselves streamed.
201 str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
202 << "AQ vector size: " << GetVectorSizeAQ() << "\n"
203 << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
204 << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
205 << ", " << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n"
206 << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
207 << "\n"
208 << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
209 << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
210 << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
211 << "KPack: " << BlockGemm::Traits::KPack << "\n"
212 << "PrefetchStages: " << PrefetchStages << "\n";
213 return str.str();
214 }
215
216 template <GemmPipelineScheduler Scheduler>
218 {
219 };
220
221 template <>
223 {
225
// Runs the pipeline for one block tile and returns the accumulated C block tile.
// Stages, in order: compile-time validation of window data types and lengths; setup of the
// A/B LDS views and the DRAM/LDS tile windows; initial DRAM prefetch and LDS prefill; further
// prefetches for stages 1..PrefetchStages-1; the hot loop (only if HasHotLoop); the tail.
// Parameters:
//   a/b/aq_dram_block_window_tmp  - source tile windows for A, B and the A-quant (AQ) data.
//   a_element_func/b_element_func - passed to Base::LocalPrefill together with the A/B tiles.
//   m        - unused (explicitly discarded below).
//   num_loop - bounds the hot loop through `i < (num_loop - PrefetchStages)`.
//   p_smem   - handed to Base::GetABLdsTensorViews to obtain the A/B LDS views.
// NOTE(review): several original lines are absent from this listing (227, 246, 248, 316-318,
// 332, 334, 336, 344, 355, 384, 390, 394, 396, 403, 410, 412, 419, 441, 451); the comments
// below describe only what is visible.
226 template <bool HasHotLoop,
228 typename ADramBlockWindowTmp,
229 typename BDramBlockWindowTmp,
230 typename AQDramBlockWindowTmp,
231 typename AElementFunction,
232 typename BElementFunction>
233 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
234 const AElementFunction& a_element_func,
235 const BDramBlockWindowTmp& b_dram_block_window_tmp,
236 const BElementFunction& b_element_func,
237 const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
238 index_t m,
239 index_t num_loop,
240 void* p_smem) const
241 {
242 (void)m; // unused variable
// The element types of the three DRAM windows must match the Problem's A/B/AQ data types.
243 static_assert(
244 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
245 std::is_same_v<BDataType,
247 std::is_same_v<AQDataType,
249 "A/B/AQ Dram block window should have the same data type as appropriate "
250 "([A|B|AQ]DataType) defined in Problem definition!");
251
// Layout predicates; they select which window-length order is expected and which way the
// K step is applied further down.
252 constexpr bool is_a_col_major =
253 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
254 constexpr bool is_aq_col_major =
255 std::is_same_v<AQLayout, tensor_layout::gemm::ColumnMajor>;
256 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
257
// AQ is only accepted as row major, with window lengths (MPerBlock, KPerBlockAQ).
258 static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
259 static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
260 KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
261 "Aq block window has incorrect lengths for defined AqLayout!");
262
// A window lengths: (K, M) when A is column major, (M, K) otherwise.
263 static_assert(is_a_col_major
264 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
265 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
266 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
267 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
268 "A block window has incorrect lengths for defined ALayout!");
// B window lengths: (K, N) when B is row major, (N, K) otherwise.
269 static_assert(is_b_row_major
270 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
271 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
272 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
273 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
274 "B block window has incorrect lengths for defined BLayout!");
275
276 // A/B tiles in LDS - using the same approach as regular gemm pipeline
277 auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
278 auto& a_lds_block = ab_lds_blocks.at(I0{});
279 auto& b_lds_block = ab_lds_blocks.at(I1{});
280
281 // Tile distribution for load from lds
282 constexpr auto a_lds_load_tile_distr =
283 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
284 constexpr auto b_lds_load_tile_distr =
285 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
286
// GetAWindows/GetBWindows each return three windows, unpacked by index: the DRAM copy
// window (I0), the LDS copy window (I1) and the LDS window consumed by the block gemm (I2).
287 auto a_windows =
288 Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
289 auto& a_copy_dram_window = a_windows.at(I0{});
290 auto& a_copy_lds_window = a_windows.at(I1{});
291 auto& a_lds_gemm_window = a_windows.at(I2{});
292
293 auto b_windows =
294 Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
295 auto& b_copy_dram_window = b_windows.at(I0{});
296 auto& b_copy_lds_window = b_windows.at(I1{});
297 auto& b_lds_gemm_window = b_windows.at(I2{});
298
299 auto aq_copy_dram_window = Base::GetAQDramLoadWindow(aq_dram_block_window_tmp);
300
301 auto block_gemm = BlockGemm();
302 auto c_block_tile = block_gemm.MakeCBlockTile();
303
// Per-thread tile types, derived from the tile distributions of the DRAM copy windows.
304 using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
305 using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
306 using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
307
308 using ABlockTile =
309 decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
310 using BBlockTile =
311 decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
312 using AQBlockTile =
313 decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
314
315 // Memory pipeline uses multiple prefetch stages
// NOTE(review): the declarations of a_block_tiles / b_block_tiles / aq_block_tiles (original
// lines 316-318) are missing from this listing.
319
320 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
321 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
322 using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex;
323
// Window steps advance by one K tile; the layout decides which of the two index positions
// carries the K extent.
324 constexpr ADramTileWindowStep a_dram_tile_window_step =
325 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
326 constexpr BDramTileWindowStep b_dram_tile_window_step =
327 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
328 constexpr AQDramTileWindowStep aq_dram_tile_window_step =
329 is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
330
331 // Global prefetch initialization - DRAM to VGPRs
// NOTE(review): the callee of the next three statements (original lines 332, 334, 336) is
// missing from this listing; only the argument lists (tile slot I0, window, step) are visible.
333 a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
335 b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
337 aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step);
338
// Zero the C accumulator tile before any gemm contributes to it.
339 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
340
341 // LDS prefill - VGPRs to LDS
// Column-major A (and row-major B below) is transposed into a shuffled temporary before
// being written to LDS; the other layouts are written directly from slot I0.
// NOTE(review): both shuffle temporaries here use MakeShuffled2DStaticTileDistribution, while
// the hot loop uses MakeShuffledARegTileDistribution / MakeShuffledBRegTileDistribution —
// confirm the difference is intended.
342 if constexpr(is_a_col_major)
343 {
345 Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
346 transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
347 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
348 }
349 else
350 {
351 Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
352 }
353 if constexpr(is_b_row_major)
354 {
356 Policy::template MakeShuffled2DStaticTileDistribution<Problem>());
357 transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
358 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
359 }
360 else
361 {
362 Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
363 }
364 // Additional prefetching for memory pipeline - DRAM to VGPRs
// Fills tile slots 1..PrefetchStages-1 for A, B and AQ (unrolled at compile time).
365 static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
366 Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
367 a_copy_dram_window,
368 a_dram_tile_window_step);
369 Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
370 b_copy_dram_window,
371 b_dram_tile_window_step);
372 Base::GlobalPrefetch(aq_block_tiles.get(number<prefetch_idx>{}),
373 aq_copy_dram_window,
374 aq_dram_tile_window_step);
375 });
376
377 // Main hot loop for memory pipeline
// Each do-while iteration processes PrefetchStages stages and advances i by PrefetchStages.
// Per stage: run the block gemm with AQ slot prefetch_idx, write A/B slot
// (prefetch_idx + 1) % PrefetchStages to LDS, then refill slot prefetch_idx from DRAM.
378 if constexpr(HasHotLoop)
379 {
380 index_t i = 0;
381 do
382 {
383 static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
385 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
386 block_gemm(c_block_tile,
387 aq_block_tiles.get(number<prefetch_idx>{}),
388 a_lds_gemm_window,
389 b_lds_gemm_window);
391 // Prepare next iteration data
392 if constexpr(is_a_col_major)
393 {
395 Policy::template MakeShuffledARegTileDistribution<Problem>());
397 a_shuffle_tmp,
398 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
399 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
400 }
401 else
402 {
404 a_copy_lds_window,
405 a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
406 a_element_func);
407 }
408 if constexpr(is_b_row_major)
409 {
411 Policy::template MakeShuffledBRegTileDistribution<Problem>());
413 b_shuffle_tmp,
414 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
415 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
416 }
417 else
418 {
420 b_copy_lds_window,
421 b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
422 b_element_func);
423 }
424
425 Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
426 a_copy_dram_window,
427 a_dram_tile_window_step);
428 Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
429 b_copy_dram_window,
430 b_dram_tile_window_step);
431 Base::GlobalPrefetch(aq_block_tiles.get(number<prefetch_idx>{}),
432 aq_copy_dram_window,
433 aq_dram_tile_window_step);
434 });
435
436 i += PrefetchStages;
437 } while(i < (num_loop - PrefetchStages));
438 }
439
440 // Tail handling
// One gemm on the data currently in LDS, using AQ slot I0.
// NOTE(review): the tail indexes slots I0/I1 directly rather than through PrefetchStages —
// presumably this relies on PrefetchStages being 2; confirm.
442 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
443 block_gemm(
444 c_block_tile, aq_block_tiles.get(I0{}), a_lds_gemm_window, b_lds_gemm_window);
445
// Even tail: write A/B slot I1 to LDS and run one further gemm with AQ slot I1.
// NOTE(review): unlike the prefill paths above, this branch calls LocalPrefill directly and
// does not take the is_a_col_major / is_b_row_major shuffle path — confirm this is intended.
446 if constexpr(TailNum == TailNumber::Even)
447 {
448
449 Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I1{}), a_element_func);
450 Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I1{}), b_element_func);
452 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
453 block_gemm(
454 c_block_tile, aq_block_tiles.get(I1{}), a_lds_gemm_window, b_lds_gemm_window);
455 }
456 return c_block_tile;
457 }
458 };
459
// Convenience overload without element functions: forwards the three DRAM windows, m,
// num_loop and p_smem to an `operator()<HasHotLoop, TailNum>` call, supplying identity
// lambdas as the A and B element functions. Returns whatever that call returns.
// NOTE(review): the object the call is made on (original line 471) is missing from this
// listing.
460 template <typename ADramBlockWindowTmp,
461 typename BDramBlockWindowTmp,
462 typename AQDramBlockWindowTmp>
463 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
464 const BDramBlockWindowTmp& b_dram_block_window_tmp,
465 const AQDramBlockWindowTmp& aq_dram_block_window_tmp,
466 index_t m,
467 index_t num_loop,
468 void* p_smem) const
469 {
470
472 .template operator()<HasHotLoop, TailNum>(
473 a_dram_block_window_tmp,
// Identity element function for A.
474 [](const ADataType& a) { return a; },
475 b_dram_block_window_tmp,
// Identity element function for B.
476 [](const BDataType& b) { return b; },
477 aq_dram_block_window_tmp,
478 m,
479 num_loop,
480 p_smem);
481 }
482};
483
484} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
typename impl::tuple_array_impl< T, N >::type tuple_array
Definition tile/core/container/tuple.hpp:28
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Interwave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:17
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, const AQDramBlockWindowTmp &aq_dram_block_window_tmp, index_t m, index_t num_loop, void *p_smem) const
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:233
PipelineImplBase Base
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:224
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:218
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:80
remove_cvref_t< typename Problem::BDataType > BDataType
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:86
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:88
number< 1 > I1
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:95
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:111
remove_cvref_t< typename Problem::AQDataType > AQDataType
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:85
remove_cvref_t< typename Problem::ALayout > ALayout
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:106
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:87
static constexpr bool kPadN
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:131
remove_cvref_t< typename Problem::ADataType > ADataType
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:84
static constexpr index_t GetSmemPackB()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:128
remove_cvref_t< typename Problem::BLayout > BLayout
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:108
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:109
remove_cvref_t< typename Problem::AQLayout > AQLayout
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:107
remove_cvref_t< typename Problem::QuantGroupSize > QuantGroupSize
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:89
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_mem.hpp:46
static CK_TILE_HOST std::string Print()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:163
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, const AQDramBlockWindowTmp &aq_dram_block_window_tmp, index_t m, index_t num_loop, void *p_smem) const
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:463
static constexpr index_t MPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:114
static constexpr index_t GetVectorSizeC()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:121
static constexpr bool HasHotLoop
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:137
BaseGemmPipelineAgBgCrMem< Problem > Base
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:81
GemmAQuantPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:82
static constexpr auto Scheduler
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:139
static constexpr index_t NPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:115
static CK_TILE_HOST const std::string GetName()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:143
static constexpr index_t BlockSize
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:113
static constexpr index_t AQPackedSize
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:103
static constexpr index_t GetVectorSizeB()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:120
static constexpr index_t GetSmemPackA()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:127
static constexpr index_t KPerBlockAQ
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:117
static constexpr bool DoubleSmemBuffer
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:134
number< 0 > I0
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:94
static constexpr bool PreshuffleQuant
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:135
static constexpr index_t GetVectorSizeA()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:119
number< 2 > I2
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:96
static constexpr index_t GetVectorSizeAQ()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:122
static constexpr index_t KPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:116
static constexpr bool kPadM
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:130
static constexpr index_t APackedSize
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:98
static constexpr index_t BPackedSize
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:100
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:158
static constexpr bool kPadK
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:132
static constexpr auto TailNum
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:138
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:20
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:34
static CK_TILE_HOST_DEVICE constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition gemm_aquant_pipeline_ag_bg_cr_mem.hpp:21
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:19
Definition gemm_pipeline_ag_bg_cr_mem.hpp:19
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_mem.hpp:46
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:14
static constexpr index_t NPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:26
static constexpr index_t KPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:27
typename Base::BDataType BDataType
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:18
CK_TILE_DEVICE constexpr auto GetAQDramLoadWindow(const AQDramBlockWindowTmp &aq_dram_block_window_tmp) const
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:37
static constexpr index_t KPerBlockAQ
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:29
static constexpr index_t MPerBlock
Definition gemm_aquant_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile &dst_block_tile, SrcTileWindow &dram_tile_window, const DramTileWindowStep &dram_tile_window_step) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:39
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/utility/functional.hpp:43