blkgemmpipe_scheduler.hpp Source File

blkgemmpipe_scheduler.hpp Source File

Composable Kernel: blkgemmpipe_scheduler.hpp Source File
blkgemmpipe_scheduler.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10
// Block-GEMM pipeline variants; the meaning of each version is given by the per-enumerator
// comments below.
// NOTE(review): the declaration line that precedes this brace (original source line 11) was
// lost in extraction; the cross-reference index names this enum BlockGemmPipelineVersion.
// Restore the exact declaration (enum key / underlying type) from the real header.
12{
13 // For GEMM
14 v1, // Naive
15 v2, // Mem
16 v3, // Comp
17 v4, // Comp, double lds buffer
18 v5, // Comp, double global prefetch register buffer
19
20 // For GEMM with preshuffled weight
21 // v1, single lds buffer
22 // v2, double lds buffer
// NOTE(review): the two lines above are comments, not enumerators -- presumably the
// preshuffled-weight pipelines reuse v1/v2 with these meanings; confirm with their users.
23};
29
// Tail case of a pipelined GEMM main loop. The enumerators are grouped by the kind of
// pipeline that uses them (see the group comments).
enum struct TailNumber
{
    // Single / Double buffer pipeline
    Odd,
    Even,

    // Long prefetch pipeline, up to 8
    One,
    Two,
    Three,
    Four,
    Five,
    Six,
    Seven,

    // Unroll stages > Prefetch stages: the number of loops is a multiple of the unroll stages
    Empty,
    // Unroll stages <= Prefetch stages: the number of loops is a multiple of the unroll stages
    // plus the prefetch stages
    Full,
};
51
// Instruction-class mask values; each enumerator is a single distinct bit, so values can be
// OR-ed together.
// NOTE(review): the declaration line that precedes this brace (original source line 52) was
// lost in extraction; the cross-reference index names this enum SchedulerGroup and lists
// uint32_t, so it is presumably `enum SchedulerGroup : uint32_t` -- confirm against the header.
// NOTE(review): these bits look like the AMDGPU sched_group_barrier instruction masks
// (MFMA / VMEM / DS read / DS write) -- confirm before relying on that.
53{
54 SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions
55 SCHED_GROUP_VMEM = 0x020, // Global memory operations
56 SCHED_GROUP_LDS_READ = 0x100, // LDS read operations
57 SCHED_GROUP_LDS_WRITE = 0x200 // LDS write operations
58};
59
60template <index_t BlockSize,
61 index_t MPerBlock,
62 index_t NPerBlock,
63 index_t KPerBlock,
64 index_t ABufferLoadWidth,
65 index_t BBufferLoadWidth,
66 index_t ALDSWriteWidth,
67 index_t BLDSWriteWidth,
68 index_t ALDSReadWidth,
69 index_t BLDSReadWidth,
70 index_t MRepeat,
71 index_t NRepeat,
72 index_t MPerXDL,
73 index_t NPerXDL,
74 index_t KPerXDL,
75 bool IsF4F6 = false>
76struct BlockwiseGemmXdlops_pipeline_hotloop_inst
77{
78 static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
79 static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
80 static constexpr index_t WaveSize = BlockSize / WaveNumM / WaveNumN;
81
82 static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
83 static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
84
85 static constexpr index_t A_Buffer_Load_Inst_Num =
86 MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
87 static constexpr index_t B_Buffer_Load_Inst_Num =
88 NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
89
90 static constexpr index_t A_LDS_Write_Inst_Num =
91 MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
92 static constexpr index_t B_LDS_Write_Inst_Num =
93 NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
94
95 static constexpr index_t A_LDS_Read_Inst_Num =
96 WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
97 static constexpr index_t B_LDS_Read_Inst_Num =
98 WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
99
100 static constexpr index_t C_MFMA_Inst_Num =
101 MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
102
103 static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1;
104
105 static constexpr index_t C_MFMA_Inst_Cycle = []() {
106 if constexpr(NPerXDL == 16)
107 {
108 return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp;
109 }
110 else if constexpr(NPerXDL == 32)
111 {
112 return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp;
113 }
114 }();
115
116 static constexpr auto Print()
117 {
118 printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
119 BlockSize,
120 WaveSize,
121 MPerBlock,
122 NPerBlock,
123 KPerBlock,
124 MPerXDL,
125 NPerXDL,
126 KPerXDL);
127
128 printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
129 "%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n"
130 "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
131 "%d/ %d\n",
142 ALDSWriteWidth,
143 BLDSWriteWidth,
144 ABufferLoadWidth,
145 BBufferLoadWidth);
146 }
147};
148
149} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Empty
Definition blkgemmpipe_scheduler.hpp:46
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
SchedulerGroup
Definition blkgemmpipe_scheduler.hpp:53
@ SCHED_GROUP_LDS_READ
Definition blkgemmpipe_scheduler.hpp:56
@ SCHED_GROUP_MFMA
Definition blkgemmpipe_scheduler.hpp:54
@ SCHED_GROUP_LDS_WRITE
Definition blkgemmpipe_scheduler.hpp:57
@ SCHED_GROUP_VMEM
Definition blkgemmpipe_scheduler.hpp:55
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v4
Definition gridwise_gemm_pipeline_selector.hpp:22
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
unsigned int uint32_t
Definition stdint.h:126
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops.hpp:37
static constexpr auto Print()
Definition blkgemmpipe_scheduler.hpp:116
static constexpr index_t WaveNumN
Definition blockwise_gemm_pipeline_xdlops.hpp:36
static constexpr index_t WaveNumM
Definition blockwise_gemm_pipeline_xdlops.hpp:35