device_gemm_xdl.hpp Source File

device_gemm_xdl.hpp Source File

Composable Kernel: device_gemm_xdl.hpp Source File
device_gemm_xdl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23template <typename ADataType,
24 typename BDataType,
25 typename CDataType,
26 typename AccDataType,
27 typename ALayout,
28 typename BLayout,
29 typename CLayout,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CElementwiseOperation,
33 GemmSpecialization GemmSpec,
34 ck::index_t BlockSize,
35 ck::index_t MPerBlock,
36 ck::index_t NPerBlock,
37 ck::index_t K0PerBlock,
38 ck::index_t K1,
39 ck::index_t MPerXDL,
40 ck::index_t NPerXDL,
41 ck::index_t MXdlPerWave,
42 ck::index_t NXdlPerWave,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 ck::index_t ABlockTransferSrcVectorDim,
47 ck::index_t ABlockTransferSrcScalarPerVector,
48 ck::index_t ABlockTransferDstScalarPerVector_K1,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
53 ck::index_t BBlockTransferSrcVectorDim,
54 ck::index_t BBlockTransferSrcScalarPerVector,
55 ck::index_t BBlockTransferDstScalarPerVector_K1,
56 bool BBlockLdsAddExtraN,
57 ck::index_t CThreadTransferSrcDstVectorDim,
58 ck::index_t CThreadTransferDstScalarPerVector,
59 ck::index_t NumPrefetch = 1,
62struct DeviceGemmXdl : public DeviceGemm<ALayout,
63 BLayout,
64 CLayout,
65 ADataType,
66 BDataType,
67 CDataType,
68 AElementwiseOperation,
69 BElementwiseOperation,
70 CElementwiseOperation>
71{
// NXdlPerWave values used for wave64 (NXdlPerWave64) and wave32 (NXdlPerWave32)
// execution. GetNXdlPerWave<bool>() is not defined in this listing; presumably it
// is generated by GET_NXDL_PER_WAVE_IMPL (device_base.hpp) on the preceding,
// omitted source line — TODO confirm.
// IsSupportedArgument only validates through a variant whose value is > 0, so a
// value of 0 appears to mean "no usable configuration for that wave size".
73 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
74 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
75
76 static constexpr auto I0 = Number<0>{};
77 static constexpr auto I1 = Number<1>{};
78 static constexpr auto I2 = Number<2>{};
79
80 static constexpr auto K1Number = Number<K1>{};
81
82 // GridwiseGemm
// Alias template over the gridwise GEMM kernel. The only free parameter is
// NXdlPerWave_, so the same tile configuration can be instantiated separately
// for the wave64 and wave32 NXdlPerWave values.
// NOTE(review): source lines 84 and 89 are missing from this listing. The Doxygen
// entry for GridwiseGemmBase shows the aliased type is
// GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext<...> and that the argument after
// CDataType is InMemoryDataOperationEnum::Set.
83 template <index_t NXdlPerWave_>
85 BlockSize,
// BDataType is not forwarded: the gridwise kernel receives ADataType for A and B.
86 ADataType, // TODO: distinguish A/B datatype
87 AccDataType,
88 CDataType,
90 ALayout,
91 BLayout,
92 CLayout,
93 AElementwiseOperation,
94 BElementwiseOperation,
95 CElementwiseOperation,
96 GemmSpec,
97 MPerBlock,
98 NPerBlock,
99 K0PerBlock,
100 MPerXDL,
101 NPerXDL,
102 K1,
103 MXdlPerWave,
104 NXdlPerWave_,
105 ABlockTransferThreadClusterLengths_K0_M_K1,
106 ABlockTransferThreadClusterArrangeOrder,
107 ABlockTransferSrcAccessOrder,
108 ABlockTransferSrcVectorDim,
109 ABlockTransferSrcScalarPerVector,
110 ABlockTransferDstScalarPerVector_K1,
111 false, // AThreadTransferSrcResetCoordinateAfterRun,
112 ABlockLdsAddExtraM,
113 BBlockTransferThreadClusterLengths_K0_N_K1,
114 BBlockTransferThreadClusterArrangeOrder,
115 BBlockTransferSrcAccessOrder,
116 BBlockTransferSrcVectorDim,
117 BBlockTransferSrcScalarPerVector,
118 BBlockTransferDstScalarPerVector_K1,
119 false, // BThreadTransferSrcResetCoordinateAfterRun,
120 BBlockLdsAddExtraN,
// Fixed (not user-configurable) access order for the C thread transfer.
121 Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
122 CThreadTransferSrcDstVectorDim,
123 CThreadTransferDstScalarPerVector,
124 NumPrefetch,
125 LoopSched,
126 PipelineVer>;
129
131
132 // Invoker
133 struct Invoker : public BaseInvoker
134 {
// Launches the kernel for one concrete GridwiseGemm instantiation (the template
// parameter) and returns the float produced by launch_and_time_kernel.
// Steps:
//  - print the argument when stream_config.log_level_ > 0;
//  - throw std::runtime_error if GridwiseGemm::CheckValidity(karg) fails;
//  - size the grid with CalculateGridSize(M, N);
//  - launch with BlockSize threads per block and lds_byte = 0.
// NOTE(review): source lines 156 and 163, which define `kernel` in each branch,
// are missing from this listing. Presumably they instantiate the gridwise kernel
// with and without a main K-block loop — TODO confirm.
135 template <typename GridwiseGemm>
136 float RunImp(const typename GridwiseGemm::Argument& karg,
137 const StreamConfig& stream_config = StreamConfig{})
138 {
139 if(stream_config.log_level_ > 0)
140 {
141 karg.Print();
142 }
143
144 if(!GridwiseGemm::CheckValidity(karg))
145 {
146 throw std::runtime_error(
147 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
148 }
149
150 const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
151
152 float ave_time = 0;
153
// Both branches launch identically; they differ only in the `kernel` selected on
// the omitted lines, depending on whether K needs a main K-block loop.
154 if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
155 {
157
158 ave_time = launch_and_time_kernel(
159 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
160 }
161 else
162 {
164
165 ave_time = launch_and_time_kernel(
166 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
167 }
168
169 return ave_time;
170 }
171
173
174 // polymorphic
175 float Run(const BaseArgument* p_arg,
176 const StreamConfig& stream_config = StreamConfig{}) override
177 {
178 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
179 }
180 };
181
// Compile-time hook for rejecting invalid template-parameter combinations.
// Every instantiation is currently accepted.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
187
// Host-side check of whether this instantiation can run `karg` on the current
// device. Returns false when:
//  - a device-specific check fails. The conditions for the gfx908 branch and the
//    following branch are on source lines 192-193, 198 and 200-201, which are
//    missing from this listing. The final `else` (source line 206) rejects
//    everything else.
//  - K is not a multiple of K1.
//  - the NXdlPerWave variant for the current wave size is 0.
// Otherwise the result presumably comes from the gridwise kernel's own validity
// check on source lines 219/226 (also missing here) — TODO confirm.
188 static bool IsSupportedArgument(const Argument& karg)
189 {
190 if(ck::get_device_name() == "gfx908")
191 {
194 {
195 return false;
196 }
197 }
199 {
202 {
203 return false;
204 }
205 }
206 else
207 {
208 return false;
209 }
210
// Reject K values that are not an exact multiple of K1.
211 if(karg.K % K1 != 0)
212 {
213 return false;
214 }
// Dispatch on the runtime wave size. A variant whose NXdlPerWave is 0 skips its
// check and falls through to the trailing `return false`.
215 if(get_warp_size() == 64)
216 {
217 if constexpr(NXdlPerWave64 > 0)
218 {
220 }
221 }
222 else
223 {
// `Argument` is GridwiseGemm64's Argument type. For wave32 it is reinterpreted as
// GridwiseGemm32's Argument.
// NOTE(review): this assumes both Argument types have identical layout — confirm.
224 if constexpr(NXdlPerWave32 > 0)
225 {
227 reinterpret_cast<const typename GridwiseGemm32::Argument&>(karg));
228 }
229 }
230 return false;
231 }
232
233 // polymorphic
234 bool IsSupportedArgument(const BaseArgument* p_arg) override
235 {
236 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
237 }
238
239 static auto MakeArgument(const ADataType* p_a,
240 const BDataType* p_b,
241 CDataType* p_c,
242 index_t M,
243 index_t N,
244 index_t K,
245 index_t StrideA,
246 index_t StrideB,
247 index_t StrideC,
248 AElementwiseOperation,
249 BElementwiseOperation,
250 CElementwiseOperation)
251 {
252 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
253 }
254
255 static auto MakeInvoker() { return Invoker{}; }
256
257 // polymorphic
258 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
259 const void* p_b,
260 void* p_c,
261 index_t M,
262 index_t N,
263 index_t K,
264 index_t StrideA,
265 index_t StrideB,
266 index_t StrideC,
267 AElementwiseOperation,
268 BElementwiseOperation,
269 CElementwiseOperation) override
270 {
271 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
272 static_cast<const BDataType*>(p_b),
273 static_cast<CDataType*>(p_c),
274 M,
275 N,
276 K,
277 StrideA,
278 StrideB,
279 StrideC);
280 }
281
282 // polymorphic
283 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
284 {
285 return std::make_unique<Invoker>(Invoker{});
286 }
287
288 // polymorphic
289 std::string GetTypeString() const override
290 {
291 auto str = std::stringstream();
292
293 std::map<LoopScheduler, std::string> LoopSchedToString{
294 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
295
296 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
297 {PipelineVersion::v2, "v2"}};
298
299 // clang-format off
300 str << "DeviceGemmXdl"
301 << "<"
302 << BlockSize << ", "
303 << MPerBlock << ", "
304 << NPerBlock << ", "
305 << K0PerBlock << ", "
306 << K1 << ", "
307 << MPerXDL << ", "
308 << NPerXDL << ", "
309 << MXdlPerWave << ", "
310 << NXdlPerWave << ", "
311 << ABlockTransferSrcScalarPerVector << ", "
312 << ABlockTransferDstScalarPerVector_K1 << ", "
313 << BBlockTransferSrcScalarPerVector << ", "
314 << BBlockTransferDstScalarPerVector_K1
315 << ">"
316 << " NumPrefetch: "
317 << NumPrefetch << ", "
318 << "LoopScheduler: "
319 << LoopSchedToString[LoopSched] << ", "
320 << "PipelineVersion: "
321 << PipelineVersionToString[PipelineVer];
322 // clang-format on
323
324 return str.str();
325 }
326};
327
328} // namespace device
329} // namespace tensor_operation
330} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:34
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:240
Definition gridwise_gemm_xdlops_v2r3.hpp:814
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_gemm.hpp:22
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl.hpp:175
float RunImp(const typename GridwiseGemm::Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl.hpp:136
Definition device_gemm_xdl.hpp:71
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_xdl.hpp:188
static constexpr auto K1Number
Definition device_gemm_xdl.hpp:80
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl.hpp:128
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl.hpp:73
static constexpr auto I0
Definition device_gemm_xdl.hpp:76
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl.hpp:234
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl.hpp:130
static auto MakeInvoker()
Definition device_gemm_xdl.hpp:255
std::string GetTypeString() const override
Definition device_gemm_xdl.hpp:289
static constexpr auto I2
Definition device_gemm_xdl.hpp:78
static constexpr auto I1
Definition device_gemm_xdl.hpp:77
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl.hpp:239
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 0, 2, 4, 5, 6, 1, 3, 7 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumPrefetch, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_gemm_xdl.hpp:84
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl.hpp:182
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl.hpp:283
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl.hpp:74
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl.hpp:127
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl.hpp:258