blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp Source File

blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp Source File
blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::I0;
122 using Base::I1;
123 using Base::KRepeat;
124 using Base::xdlops_gemm;
125 using typename Base::HotLoopInstList;
126
138
141
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
144
146
// Stage counts for this pipeline.
// PrefetchStages is the threshold BlockHasHotloop() compares num_loop against; Run()
// issues two global prefetches ("Global prefetch 1" / "Global prefetch 2") and one
// local prefill before its main body, matching the values 2 and 1 below.
// NOTE(review): PrefillStages and GlobalBufferNum are not read anywhere in this listing;
// presumably they are consumed by the code that instantiates the pipeline - confirm.
147 static constexpr index_t PrefetchStages = 2;
148 static constexpr index_t PrefillStages = 1;
149 static constexpr index_t GlobalBufferNum = 1;
150
151 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
152 {
153 return num_loop > PrefetchStages;
154 }
155
156 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
157 {
158 ignore = num_loop;
159 return TailNumber::Full;
160 }
161
// Emits the instruction-scheduling hints for one hot-loop iteration as a fixed sequence
// of __builtin_amdgcn_sched_group_barrier calls. All counts are compile-time constants
// taken from HotLoopInstList. Per the inline labels, the masks used are:
//   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
// Stage 1 pairs DS writes and VMEM reads with MFMAs (A first, then B); stage 2 pairs
// DS reads with MFMAs (A first, then B).
// NOTE(review): this listing is missing several original lines (168-169, 172-173, 210,
// 221, 234, 250). Presumably these are the two arms of the num_ds_read_inst_* ternaries
// and the static_for headers that open the lambdas using `i` - restore from the real
// header before compiling.
162 __device__ static constexpr auto HotLoopScheduler()
163 {
164 // A/B split schedule
165 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
166 constexpr auto num_ds_read_inst_a =
167 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
170 constexpr auto num_ds_read_inst_b =
171 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
174
175 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
176 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
177
178 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
179 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
180
181 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
182
// Issue cost of one DS read: 8 when the read is exactly 16 bytes wide, otherwise 4.
183 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
184 constexpr auto ds_read_a_issue_cycle =
185 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
186 constexpr auto ds_read_b_issue_cycle =
187 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// DS reads grouped with each MFMA: (mfma_cycle - 4) divided by twice the issue cost,
// written in the integer round-up form (x + d - 1) / d.
188 constexpr auto ds_read_a_mfma_rate =
189 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
190 constexpr auto ds_read_b_mfma_rate =
191 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
192
// Number of MFMAs needed to cover all DS reads at that rate (again rounded up).
193 constexpr auto num_dsread_a_mfma =
194 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
195 constexpr auto num_dsread_b_mfma =
196 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
197
198 // stage 1
199 // Separate this part?
200 // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
201 // sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
202 // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
203 // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
// MFMAs left for stage 1 are those not reserved for the DS-read pairing of stage 2;
// they are divided (integer division) across all A and B buffer-load instructions.
204 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
205 constexpr auto num_mfma_per_issue =
206 num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
207 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
208 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
209
// A side: each DS write is followed by one MFMA, then one VMEM read, then the MFMAs
// remaining in this issue slot (num_mfma_per_issue minus those already paired).
211 ignore = i;
212 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
213 ignore = idswrite;
214 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
215 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
216 });
217 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
218 __builtin_amdgcn_sched_group_barrier(
219 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
220 });
// B side: same pattern as the A side, using the B counts.
222 ignore = i;
223 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
224 ignore = idswrite;
225 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
226 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
227 });
228 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
229 __builtin_amdgcn_sched_group_barrier(
230 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
231 });
232
233 // stage 2
// A side: a full group of ds_read_a_mfma_rate DS reads while at least that many remain
// after this step; otherwise the leftover count. One MFMA follows each group.
235 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
236 ds_read_a_mfma_rate)
237 {
238 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
239 }
240 else
241 {
242 __builtin_amdgcn_sched_group_barrier(0x100,
243 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
244 ds_read_a_mfma_rate,
245 0); // DS read
246 }
247 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
248 });
249
// B side: same grouping rule as the A side, using the B counts.
251 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
252 ds_read_b_mfma_rate)
253 {
254 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
255 }
256 else
257 {
258 __builtin_amdgcn_sched_group_barrier(0x100,
259 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
260 ds_read_b_mfma_rate,
261 0); // DS read
262 }
263 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
264 });
265 }
266
// Runs the block-level GEMM pipeline with per-block B scaling and accumulates into
// c_thread_buf (which is cleared here before any MFMA is issued).
//
// Visible order of operations:
//   1. Global prefetch 1 of A and B, then advance both source windows.
//   2. Read B scales for every n0 in [0, NRepeat) and advance the scale window.
//   3. Local prefill: write the prefetched A/B tiles into the block buffers.
//   4. Global prefetch 2 of A and B, then advance both source windows.
//   5. Clear C, then local prefetch of A (per m0) and scaled B (per n0) for every k0.
//   6. If HasMainLoop: do/while loop that repeats write / read / scale-read / MFMA /
//      local prefetch until i reaches num_loop - 1.
//   7. If TailNum == TailNumber::Full: one final MFMA pass over the thread buffers.
//
// Parameters:
//   a_* / b_*             descriptors, copy objects, buffers and window steps for A and B.
//   c_thread_buf          accumulator; cleared, then updated through xdlops_gemm.Run.
//   b_scale_*             descriptors, copy object, grid buffer and window steps for the
//                         B scales; steps 0, 1 and 2 of b_scale_thread_copy_step are used.
//   num_loop              loop-count input; the main loop runs while i < num_loop - 1.
//   num_loop_per_scale    used as a modulo divisor below, so it must be non-zero.
//
// NOTE(review): this listing is missing a number of original lines (312, 314, 318, 368,
// 371-372, 374, 380-381, 384, 398, 434-435, 447, 461, 465-466, 468, 473-474, 478, 484,
// 496-497). Presumably these declare a_thread_buf / b_thread_buf / b_scale_thread_buf and
// a_thread_vec / b_thread_vec, and hold the heads of the thread-copy calls and sync /
// scheduler calls - restore from the real header before compiling.
267 template <bool HasMainLoop,
268 TailNumber TailNum,
269 typename AGridDesc,
270 typename ABlockDesc,
271 typename ABlockTransfer,
272 typename AGridBuffer,
273 typename ABlockBuffer,
274 typename ABlockTransferStep,
275 typename BGridDesc,
276 typename BBlockDesc,
277 typename BBlockTransfer,
278 typename BGridBuffer,
279 typename BBlockBuffer,
280 typename BBlockTransferStep,
281 typename CThreadBuffer,
282 typename BScaleGridBuffer,
283 typename BScaleGridDesc,
284 typename BScaleThreadDesc,
285 typename BScaleThreadTransfer,
286 typename BScaleThreadTransferStep>
287 __device__ void Run(const AGridDesc& a_grid_desc,
288 const ABlockDesc& a_block_desc,
289 ABlockTransfer& a_blockwise_copy,
290 const AGridBuffer& a_grid_buf,
291 ABlockBuffer& a_block_buf,
292 const ABlockTransferStep& a_block_copy_step,
293 const BGridDesc& b_grid_desc,
294 const BBlockDesc& b_block_desc,
295 BBlockTransfer& b_blockwise_copy,
296 const BGridBuffer& b_grid_buf,
297 BBlockBuffer& b_block_buf,
298 const BBlockTransferStep& b_block_copy_step,
299 CThreadBuffer& c_thread_buf,
300 // BScaleThreadCopy
301 const BScaleGridDesc& b_scale_grid_desc,
302 const BScaleThreadDesc& b_scale_thread_desc,
303 BScaleThreadTransfer& b_scale_thread_copy,
304 const BScaleGridBuffer& b_scale_grid_buf,
305 const BScaleThreadTransferStep& b_scale_thread_copy_step,
306 // num loop
307 index_t num_loop,
308 index_t num_loop_per_scale) const
309 {
310 __builtin_amdgcn_sched_barrier(0);
311
313 a_thread_desc_.GetElementSpaceSize());
315 b_thread_desc_.GetElementSpaceSize());
316
317 // B scale buffer
319 b_scale_thread_desc.GetElementSpaceSize());
320
321 // Global prefetch 1
322 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
323 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
324
325 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
326 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
327
// Read the B scales for each n0 into b_scale_thread_buf at (n0, 0), moving the scale
// source window by step 0 after every read.
328 static_for<0, NRepeat, 1>{}([&](auto n0) {
329 b_scale_thread_copy.Run(b_scale_grid_desc,
330 b_scale_grid_buf,
331 b_scale_thread_desc,
332 make_tuple(n0, I0),
333 b_scale_thread_buf);
334
335 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
336 b_scale_thread_copy_step.At(Number<0>{}));
337 });
338
// Final window move after the per-n0 reads: step 2 when each loop iteration has its own
// scale (num_loop_per_scale == 1), otherwise step 1.
// NOTE(review): the meaning of steps 1 and 2 is defined by the caller - confirm there.
339 if(num_loop_per_scale == 1)
340 {
341 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
342 b_scale_thread_copy_step.At(Number<2>{}));
343 }
344 else
345 {
346 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
347 b_scale_thread_copy_step.At(Number<1>{}));
348 }
349
// num_scale_k_block is the length of dimension 1 of the scale thread descriptor;
// num_scale_krepeat (integer division) is how many consecutive k0 values share one
// scale entry, as used by the k0 / num_scale_krepeat index below.
350 constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1);
351 constexpr auto num_scale_krepeat = KRepeat / num_scale_k_block;
352
353 // Local prefill 1
354 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
355 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
356
357 // Global prefetch 2
358 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
359 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
360
361 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
362 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
363
364 // Initialize C
365 c_thread_buf.Clear();
366
367 // Local prefetch 1
// For every k0: copy A for each m0 and B for each n0 from the block buffers into the
// thread buffers; the B copy is also handed the scale entry selected for (n0, k0).
369 static_for<0, KRepeat, 1>{}([&](auto k0) {
370 static_for<0, MRepeat, 1>{}([&](auto m0) {
373 a_block_buf,
375 make_tuple(m0, I0, k0, I0),
376 a_thread_buf);
377 });
378 static_for<0, NRepeat, 1>{}([&](auto n0) {
379 b_thread_copy_.Run(
382 b_block_buf,
383 b_scale_thread_buf[Number<n0 * num_scale_k_block + k0 / num_scale_krepeat>{}],
385 make_tuple(n0, I0, k0, I0),
386 b_thread_buf);
387 });
388 });
389
390 __builtin_amdgcn_sched_barrier(0);
391
392 // main body
393 if constexpr(HasMainLoop)
394 {
395 index_t i = 0;
396 do
397 {
399
// Write the previously prefetched tiles to the block buffers, then start the next
// global read and advance the source windows.
400 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
401 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
402
403 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
404 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
405
406 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
407 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
408
409 static_for<0, NRepeat, 1>{}([&](auto n0) {
410 b_scale_thread_copy.Run(b_scale_grid_desc,
411 b_scale_grid_buf,
412 b_scale_thread_desc,
413 make_tuple(n0, I0),
414 b_scale_thread_buf);
415
416 b_scale_thread_copy.MoveSrcSliceWindow(
417 b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
418 });
419
// Pick the scale-window step from the iteration index offset by 2.
// NOTE(review): the +2 presumably accounts for the two prefetch stages already issued
// in the prologue - confirm. num_loop_per_scale must be non-zero here: a zero divisor
// makes this modulo undefined, and no guard is visible in this listing.
420 if((i + 2) % num_loop_per_scale == 0)
421 {
422 b_scale_thread_copy.MoveSrcSliceWindow(
423 b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
424 }
425 else
426 {
427 b_scale_thread_copy.MoveSrcSliceWindow(
428 b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
429 }
430
// MFMA pass: for every (k0, m0, n0) gather KPack elements of A and B from the thread
// buffers into vectors and accumulate into the C slot at offset (m0, n0, 0).
431 static_for<0, KRepeat, 1>{}([&](auto k0) {
432 static_for<0, MRepeat, 1>{}([&](auto m0) {
433 static_for<0, NRepeat, 1>{}([&](auto n0) {
436
437 static_for<0, KPack, 1>{}([&](auto ik) {
438 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
439 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
440 make_tuple(m0, I0, k0, ik))>{}];
441 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
442 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
443 make_tuple(n0, I0, k0, ik))>{}];
444 });
445
446 using mfma_input_type =
448 xdlops_gemm.K1PerXdlops>::type;
449
450 constexpr index_t c_offset =
451 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
452
453 xdlops_gemm.Run(
454 a_thread_vec.template AsType<mfma_input_type>(),
455 b_thread_vec.template AsType<mfma_input_type>(),
456 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
457 });
458 });
459 });
460
462
// Refill the thread buffers from the block buffers for the next pass, applying the
// scales read earlier in this iteration to B.
463 static_for<0, KRepeat, 1>{}([&](auto k0) {
464 static_for<0, MRepeat, 1>{}([&](auto m0) {
467 a_block_buf,
469 make_tuple(m0, I0, k0, I0),
470 a_thread_buf);
471 });
472 static_for<0, NRepeat, 1>{}([&](auto n0) {
475 b_block_buf,
476 b_scale_thread_buf[Number<n0 * num_scale_k_block +
477 k0 / num_scale_krepeat>{}],
479 make_tuple(n0, I0, k0, I0),
480 b_thread_buf);
481 });
482 });
483
485 __builtin_amdgcn_sched_barrier(0);
486
// do/while: the body above always runs once when HasMainLoop is true, then repeats
// while i < num_loop - 1.
487 i += 1;
488 } while(i < (num_loop - 1));
489 }
490 // tail
// Final MFMA pass over whatever the last local prefetch left in the thread buffers;
// no further copies are issued in the visible code.
491 if constexpr(TailNum == TailNumber::Full)
492 {
493 static_for<0, KRepeat, 1>{}([&](auto k0) {
494 static_for<0, MRepeat, 1>{}([&](auto m0) {
495 static_for<0, NRepeat, 1>{}([&](auto n0) {
498
499 static_for<0, KPack, 1>{}([&](auto ik) {
500 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
501 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
502 make_tuple(m0, I0, k0, ik))>{}];
503 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
504 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
505 make_tuple(n0, I0, k0, ik))>{}];
506 });
507
508 using mfma_input_type =
509 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
510
511 constexpr index_t c_offset =
512 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
513
514 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
515 b_thread_vec.template AsType<mfma_input_type>(),
516 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
517 });
518 });
519 });
520 __builtin_amdgcn_sched_barrier(0);
521 }
522 }
523
524 protected:
525 using Base::a_thread_copy_;
526 using Base::a_thread_desc_;
527 using Base::b_thread_copy_;
528 using Base::b_thread_desc_;
529 using Base::c_thread_desc_;
530};
531
532} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp:287
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp:102
Definition blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp:37
Definition functional2.hpp:33
Definition dtype_vector.hpp:10