device_batched_contraction_multiple_d_wmma_cshuffle.hpp Source File

device_batched_contraction_multiple_d_wmma_cshuffle.hpp Source File#

Composable Kernel: device_batched_contraction_multiple_d_wmma_cshuffle.hpp Source File
device_batched_contraction_multiple_d_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
9#include "ck/utility/env.hpp"
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
26// Tensor Contraction:
27// input : A
28// input : B
29// input : D0, D1, ...
30// output : E
31// C = a_op(A) * b_op(B)
32// E = cde_op(C, D0, D1, ...)
33// Assume:
34// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
35// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
36// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
37// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
38
39// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner
40// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and
41// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted
42// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into
43// TensorSpecialization::Default with NumDimG/M/N/K = 1
44//
45// Detail- Packed tensor satisfies
46// stride_0 = 1
47// stride_i = stride_{i - 1} * extent_{i - 1}
48// So tensor
49// [G0, G1, G2, M, N]
50// transposed into tensor
51// [G0, G2, G1, M, N]
52// with strides
53// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
54// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some
55// strides from input tensor extents so finer dimension information is lost. Merging dimensions is
56// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1.
57//
58// Might need to expose dimension order to the interface to fully support
59// TensorSpecialization::Packed in a traditional sense of "packed" tensor
60template <index_t NumDimG,
61 index_t NumDimM,
62 index_t NumDimN,
63 index_t NumDimK,
64 typename ADataType,
65 typename BDataType,
66 typename AccDataType,
67 typename CShuffleDataType,
68 typename DsDataType,
69 typename EDataType,
70 typename AElementwiseOperation,
71 typename BElementwiseOperation,
72 typename CDEElementwiseOperation,
73 GemmSpecialization GemmSpec,
77 ck::index_t NumPrefetch,
78 ck::index_t BlockSize,
79 ck::index_t MPerBlock,
80 ck::index_t NPerBlock,
81 ck::index_t KPerBlock,
82 ck::index_t K1,
83 ck::index_t MPerWmma,
84 ck::index_t NPerWmma,
85 ck::index_t MRepeat,
86 ck::index_t NRepeat,
87 typename ABlockTransferThreadClusterLengths_K0_M_K1,
88 typename ABlockTransferThreadClusterArrangeOrder,
89 typename ABlockTransferSrcAccessOrder,
90 ck::index_t ABlockTransferSrcVectorDim,
91 ck::index_t ABlockTransferSrcScalarPerVector,
92 ck::index_t ABlockTransferDstScalarPerVector_K1,
93 bool ABlockLdsAddExtraM,
94 typename BBlockTransferThreadClusterLengths_K0_N_K1,
95 typename BBlockTransferThreadClusterArrangeOrder,
96 typename BBlockTransferSrcAccessOrder,
97 ck::index_t BBlockTransferSrcVectorDim,
98 ck::index_t BBlockTransferSrcScalarPerVector,
99 ck::index_t BBlockTransferDstScalarPerVector_K1,
100 bool BBlockLdsAddExtraN,
101 index_t CShuffleMRepeatPerShuffle,
102 index_t CShuffleNRepeatPerShuffle,
103 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
104 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
108 : public DeviceBatchedContractionMultipleD<NumDimG,
109 NumDimM,
110 NumDimN,
111 NumDimK,
112 ADataType,
113 BDataType,
114 DsDataType,
115 EDataType,
116 AElementwiseOperation,
117 BElementwiseOperation,
118 CDEElementwiseOperation>
119{
121 static constexpr index_t NumDTensor = DsDataType::Size();
122
123 static constexpr auto I0 = Number<0>{};
124 static constexpr auto I1 = Number<1>{};
125 static constexpr auto I2 = Number<2>{};
126 static constexpr auto I3 = Number<3>{};
127 static constexpr auto I4 = Number<4>{};
128 static constexpr auto I5 = Number<5>{};
129 static constexpr auto I6 = Number<6>{};
130 // K1 = Max Vector Access Pixels
131 static constexpr auto K1Number = Number<K1>{};
132
133 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
134 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
135 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
136
137 static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
138 static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
139
140 static constexpr auto AEnableLds_auto =
141 (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
142 static constexpr auto BEnableLds_auto =
143 (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
144
145 // If true, LDS is used unconditionally
146 static constexpr auto AEnableLds_manu = false;
147 static constexpr auto BEnableLds_manu = false;
148
149 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
150 static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
151
152 static constexpr auto matrix_padder =
153 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
154
155 // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
156 static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
157 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
158 {
159 assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
160 a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
161
162 const auto to_tuple = [&](auto& vec, auto start, auto end) {
163 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
164 };
165
166 const auto a_ms_ks_lengths = to_tuple(
167 a_gs_ms_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
168 const auto a_ms_ks_strides = to_tuple(
169 a_gs_ms_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
170
171 // dimension Ids for M0, M1, ...
172 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
173
174 // dimension Ids for K0, K1, ...
175 constexpr auto kDimIds =
177
178 // lengths for M0, M1, ...
179 const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
180
181 // lengths for K0, K1, ...
182 const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
183
184 const auto a_grid_desc_m_k = [&]() {
185 if constexpr(ASpec == TensorSpecialization::Packed)
186 {
187 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
188 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
189 const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
190 make_tuple(M, K),
191 make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
192 a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
193 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
194 }
195 else
196 {
197 // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
198 const auto a_grid_desc_ms_ks =
199 make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
200
201 // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
202 const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
203 a_grid_desc_ms_ks,
205 make_tuple(mDimIds, kDimIds),
207
208 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
209 }
210 }();
211
212 const auto M = a_grid_desc_m_k.GetLength(I0);
213 const auto K = a_grid_desc_m_k.GetLength(I1);
214 assert(K % K1 == 0);
215
216 if constexpr(AEnableLds)
217 {
218 const index_t K0 = K / K1;
219
221 a_grid_desc_m_k,
226 }
227 else
228 {
229 constexpr auto A_KRow = 2;
230 constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
231 const auto A_KWmma = K / WmmaK;
232
233 const auto M0 = M / MPerBlock;
234 // 0 1 0 1 2 3 4 5 6
235 // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
237 a_grid_desc_m_k,
241 make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
244 }
245 }
246
247 // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
248 static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
249 const std::vector<index_t>& b_gs_ns_ks_strides_vec)
250 {
251 assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
252 b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
253
254 const auto to_tuple = [&](auto& vec, auto start, auto end) {
255 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
256 };
257
258 const auto b_ns_ks_lengths = to_tuple(
259 b_gs_ns_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
260 const auto b_ns_ks_strides = to_tuple(
261 b_gs_ns_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
262
263 // dimension Ids for N0, N1, ...
264 constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
265
266 // dimension Ids for K0, K1, ...
267 constexpr auto kDimIds =
269
270 // lengths for K0, K1, ...
271 const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
272
273 // lengths for N0, N1, ...
274 const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
275
276 const auto b_grid_desc_n_k = [&]() {
277 if constexpr(BSpec == TensorSpecialization::Packed)
278 {
279 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
280 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
281 const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
282 make_tuple(N, K),
283 make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
284 b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
285 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
286 }
287 else
288 {
289 // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
290 const auto b_grid_desc_ns_ks =
291 make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
292
293 // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
294 const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
295 b_grid_desc_ns_ks,
297 make_tuple(nDimIds, kDimIds),
299
300 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
301 }
302 }();
303
304 const auto N = b_grid_desc_n_k.GetLength(I0);
305 const auto K = b_grid_desc_n_k.GetLength(I1);
306 assert(K % K1 == 0);
307
308 if constexpr(BEnableLds)
309 {
310 const index_t K0 = K / K1;
311
313 b_grid_desc_n_k,
318 }
319 else
320 {
321 constexpr auto B_KRow = 2;
322 constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
323 const auto B_KWmma = K / WmmaK;
324
325 const auto N0 = N / NPerBlock;
326 // 0 1 0 1 2 3 4 5 6
327 // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
329 b_grid_desc_n_k,
333 make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
336 }
337 }
338
339 // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
340 static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
341 const std::vector<index_t>& e_gs_ms_ns_strides_vec)
342 {
343 assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
344 e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
345
346 const auto to_tuple = [&](auto& vec, auto start, auto end) {
347 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
348 };
349
350 const auto e_ms_ns_lengths = to_tuple(
351 e_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
352 const auto e_ms_ns_strides = to_tuple(
353 e_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
354
355 // dimension Ids for M0, M1, ...
356 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
357
358 // dimension Ids for N0, N1, ...
359 constexpr auto nDimIds =
361
362 // lengths for M0, M1, ...
363 const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
364
365 // lengths for N0, N1, ...
366 const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
367
368 if constexpr(DESpec == TensorSpecialization::Packed)
369 {
370 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
371 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
372 const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
373 make_tuple(M, N),
374 make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
375 e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
376 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
377 }
378 else
379 {
380 // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
381 const auto e_grid_desc_ms_ns =
382 make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
383
384 // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
385 const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
386 e_grid_desc_ms_ns,
388 make_tuple(mDimIds, nDimIds),
390
391 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
392 }
393 }
394
395 // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
396 static auto MakeEGridDescriptor_G_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
397 const std::vector<index_t>& e_gs_ms_ns_strides_vec)
398 {
399 assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
400 e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
401
402 const auto to_tuple = [&](auto& vec, auto start, auto end) {
403 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
404 };
405
406 const auto e_gs_ms_ns_lengths =
407 to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
408 const auto e_gs_ms_ns_strides =
409 to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
410
411 // dimension Ids for G0, G1, ...
412 constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
413
414 // dimension Ids for M0, M1, ...
415 constexpr auto mDimIds =
417
418 // dimension Ids for N0, N1, ...
419 constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
420 NumDimG + NumDimM + NumDimN,
421 1>::type{};
422
423 // lengths for G0, G1, ...
424 const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);
425
426 // lengths for M0, M1, ...
427 const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);
428
429 // lengths for N0, N1, ...
430 const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);
431
432 if constexpr(DESpec == TensorSpecialization::Packed)
433 {
434 auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
435 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
436 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
437 const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
438 make_tuple(G, M, N),
439 make_tuple(e_gs_ms_ns_strides[Number<NumDimG - 1>{}],
440 e_gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
441 e_gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
442 // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
443 return e_grid_desc_g_mraw_nraw;
444 }
445 else
446 {
447 // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
448 const auto e_grid_desc_gs_ms_ns =
449 make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
450
451 // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
452 // N2 * ...]
453 const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
454 e_grid_desc_gs_ms_ns,
456 make_merge_transform(mLengths),
457 make_merge_transform(nLengths)),
458 make_tuple(gDimIds, mDimIds, nDimIds),
460
461 // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
462 return e_grid_desc_g_mraw_nraw;
463 }
464 }
465
467 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
468 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
469 {
470 return generate_tuple(
471 [&](auto i) {
472 return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
473 ds_gs_ms_ns_strides_vec[i]);
474 },
476 }
477
479 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
480 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
481 {
482 return generate_tuple(
483 [&](auto i) {
484 return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
485 ds_gs_ms_ns_strides_vec[i]);
486 },
488 }
489
490 // Gridwise descriptor, mapping to the whole given problem.
492 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
493
496
498 {
500 index_t batch_stride_B,
501 DsGridDesc_G_M_N ds_grid_desc_g_m_n,
502 EGridDesc_G_M_N e_grid_desc_g_m_n)
503 : batch_stride_A_(batch_stride_A),
504 batch_stride_B_(batch_stride_B),
505 ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n),
506 e_grid_desc_g_m_n_(e_grid_desc_g_m_n)
507 {
508 }
509
510 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
511 {
512 return static_cast<long_index_t>(g_idx) * batch_stride_A_;
513 }
514
515 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
516 {
517 return static_cast<long_index_t>(g_idx) * batch_stride_B_;
518 }
519
520 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
521 {
522 std::array<long_index_t, NumDTensor> ds_offset;
523
524 static_for<0, NumDTensor, 1>{}([&](auto i) {
525 ds_offset[i] = static_cast<long_index_t>(g_idx) *
526 ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
527 });
528
529 return ds_offset;
530 }
531
532 __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
533 {
534 return static_cast<long_index_t>(g_idx) *
535 e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
536 }
537
538 private:
539 index_t batch_stride_A_;
540 index_t batch_stride_B_;
541 DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
542 EGridDesc_G_M_N e_grid_desc_g_m_n_;
543 };
544
545 using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {}));
546 using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {}));
547
548 // GridwiseOp
550 // DataType Family
551 ADataType,
552 BDataType,
553 AccDataType,
554 CShuffleDataType,
555 DsDataType,
556 EDataType,
557 // InMemory Data Descriptor
558 AGridDesc,
559 BGridDesc,
562 // ElementwiseOp Family
563 AElementwiseOperation,
564 BElementwiseOperation,
565 CDEElementwiseOperation,
567 // Tiling Family
568 MPerBlock,
569 NPerBlock,
570 KPerBlock,
571 MPerWmma,
572 NPerWmma,
573 K1,
574 MRepeat,
575 NRepeat,
576 // ThreadCluster Family
577 BlockSize,
578 ABlockTransferThreadClusterLengths_K0_M_K1,
579 ABlockTransferThreadClusterArrangeOrder,
580 ABlockTransferSrcAccessOrder,
581 ABlockTransferSrcVectorDim,
582 ABlockTransferSrcScalarPerVector,
583 ABlockTransferDstScalarPerVector_K1,
584 false, // AThreadTransferSrcResetCoordinateAfterRun,
586 ABlockLdsAddExtraM,
587 BBlockTransferThreadClusterLengths_K0_N_K1,
588 BBlockTransferThreadClusterArrangeOrder,
589 BBlockTransferSrcAccessOrder,
590 BBlockTransferSrcVectorDim,
591 BBlockTransferSrcScalarPerVector,
592 BBlockTransferDstScalarPerVector_K1,
593 false, // BThreadTransferSrcResetCoordinateAfterRun,
595 BBlockLdsAddExtraN,
596 CShuffleMRepeatPerShuffle,
597 CShuffleNRepeatPerShuffle,
598 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
599 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
600 NumPrefetch,
601 LoopSched,
602 PipelineVer>;
603
604 // Argument
605 struct Argument : public BaseArgument
606 {
607 Argument(const void* p_a_grid,
608 const void* p_b_grid,
609 std::array<const void*, NumDTensor> p_ds_grid,
610 void* p_e_grid,
611 const std::vector<index_t>& a_gs_ms_ks_lengths,
612 const std::vector<index_t>& b_gs_ns_ks_lengths,
613 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
614 const std::vector<index_t>& e_gs_ms_ns_lengths,
615 const std::vector<index_t>& a_gs_ms_ks_strides,
616 const std::vector<index_t>& b_gs_ns_ks_strides,
617 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
618 const std::vector<index_t>& e_gs_ms_ns_strides,
619 index_t M01,
620 index_t N01,
621 AElementwiseOperation a_element_op,
622 BElementwiseOperation b_element_op,
623 CDEElementwiseOperation cde_element_op)
624 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
625 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
626 p_ds_grid_{},
627 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
628 a_grid_desc_{},
629 b_grid_desc_{},
633 DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
635 DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
639 M01_{M01},
640 N01_{N01},
641 a_element_op_{a_element_op},
642 b_element_op_{b_element_op},
643 cde_element_op_{cde_element_op},
644 a_mz_stride_{},
645 a_kz_stride_{},
646 b_nz_stride_{},
647 b_kz_stride_{},
649 e_nz_stride_{},
650 a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
651 b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
654 {
655 static_for<0, NumDTensor, 1>{}([&](auto i) {
656 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
657
658 // D pointer
659 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
660 });
661
662 a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
663 b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
664
666 DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
667
669 DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
670
672
676
679
680 // for sanity check of vector memory access
681 a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
682 a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1];
683 b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1];
684 b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1];
685
686 for(index_t i = 0; i < NumDTensor; ++i)
687 {
688 ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1];
689 }
690
691 e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1];
692 }
693
694 // Pointers
695 const ADataType* p_a_grid_;
696 const BDataType* p_b_grid_;
698 EDataType* p_e_grid_;
699
700 // Tensor Descriptors
707
712
713 // Block to Tile mapping
715
716 // Idle
719
720 // ElementwiseOp
721 AElementwiseOperation a_element_op_;
722 BElementwiseOperation b_element_op_;
723 CDEElementwiseOperation cde_element_op_;
724
725 // Strides for the last M/N/K dimensions of A/B/Ds/E
726 // for sanity check of vector load/store
731 std::array<index_t, NumDTensor> ds_nz_stride_;
734
737
738 // Batch Offset
740
741 // for checking vector load/store
742 // index_t MRaw_;
743 // index_t NRaw_;
744 // index_t KRaw_;
745 };
746
747 // Invoker
748 struct Invoker : public BaseInvoker
749 {
751
752 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
753 {
754 const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
755
756 const index_t grid_size =
757 arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
758
759 const auto K = [&]() {
760 if constexpr(AEnableLds)
761 {
762 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
763 }
764 else
765 {
766 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
767 arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
768 }
769 }();
770
771 auto launch_kernel = [&](auto has_main_k_block_loop) {
772 constexpr bool has_main_loop = has_main_k_block_loop.value;
773
776 ADataType,
777 BDataType,
779 EDataType,
784 AElementwiseOperation,
785 BElementwiseOperation,
786 CDEElementwiseOperation,
787 ComputePtrOffsetOfStridedBatch,
789 has_main_loop>;
790
791 return launch_and_time_kernel(stream_config,
792 kernel,
793 dim3(grid_size),
794 dim3(BlockSize),
795 0,
796 arg.p_a_grid_,
797 arg.p_b_grid_,
798 arg.p_ds_grid_,
799 arg.p_e_grid_,
800 G,
801 arg.a_grid_desc_,
802 arg.b_grid_desc_,
805 arg.a_element_op_,
806 arg.b_element_op_,
807 arg.cde_element_op_,
810 };
811
813 {
814 return launch_kernel(integral_constant<bool, true>{});
815 }
816 else
817 {
818 return launch_kernel(integral_constant<bool, false>{});
819 }
820 }
821
822 // polymorphic
823 float Run(const BaseArgument* p_arg,
824 const StreamConfig& stream_config = StreamConfig{}) override
825 {
826 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
827 }
828 };
829
// Compile-time validity check of the template parameters.
// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
835
836 static bool IsSupportedArgument(const Argument& arg)
837 {
839 {
841 {
842 printf("DeviceOp: Arch check failure\n");
843 return false;
844 }
845 }
846 else
847 {
848 return false;
849 }
850
852 arg.b_grid_desc_,
856 {
857 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
858 {
859 printf("GridwiseOp: Validity check failure\n");
860 }
861 return false;
862 }
863
864 // check vector access
865 static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
866 (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
867 "wrong!");
868
869 // vector memory access of A: could be on M or AK1 dimension
870 if constexpr(ABlockTransferSrcVectorDim == 1)
871 {
872 if(!(arg.a_mz_stride_ == 1 &&
873 arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
874 {
875 printf("DeviceOp: Vector Access A-m check failure\n");
876 return false;
877 }
878 }
879 else
880 {
881 if(!(arg.a_kz_stride_ == 1))
882 {
883 index_t LastK =
884 AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
885 if(LastK % ABlockTransferSrcScalarPerVector == 0)
886 {
887 printf("DeviceOp: Vector Access A-k check failure\n");
888 return false;
889 }
890 }
891 }
892
893 // vector memory access of B: could be on N or BK1 dimension
894 if constexpr(BBlockTransferSrcVectorDim == 1)
895 {
896 if(!(arg.b_nz_stride_ == 1 &&
897 arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
898 {
899 printf("DeviceOp: Vector Access B-n check failure\n");
900 return false;
901 }
902 }
903 else
904 {
905 if(!(arg.b_kz_stride_ == 1 &&
906 arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
907 {
908 printf("DeviceOp: Vector Access B-k check failure\n");
909 return false;
910 }
911 }
912
913 // vector memory access of Ds: always on NPerBlock dimension
914 bool valid_d_access = true;
915
916 static_for<0, NumDTensor, 1>{}([&](auto i) {
917 if(!(arg.ds_nz_stride_[i] == 1 &&
919 CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
920 0))
921 {
922 printf("DeviceOp: Vector Access D-n check failure\n");
923 valid_d_access = false;
924 }
925 });
926
927 if(valid_d_access == false)
928 {
929 return false;
930 }
931
932 // vector memory access of E: always on NPerBlock dimension
933 if(!((arg.e_nz_stride_ == 1 &&
935 CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
936 0) ||
937 CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
938 {
939 printf("DeviceOp: Vector Access E-n check failure\n");
940 return false;
941 }
942
943 return true;
944 }
945
946 // polymorphic
947 bool IsSupportedArgument(const BaseArgument* p_arg) override
948 {
949 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
950 }
951
952 static auto
953 MakeArgument(const void* p_a,
954 const void* p_b,
955 std::array<const void*, NumDTensor> p_ds,
956 void* p_e,
957 const std::vector<index_t>& a_gs_ms_ks_lengths,
958 const std::vector<index_t>& a_gs_ms_ks_strides,
959 const std::vector<index_t>& b_gs_ns_ks_lengths,
960 const std::vector<index_t>& b_gs_ns_ks_strides,
961 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
962 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
963 const std::vector<index_t>& e_gs_ms_ns_lengths,
964 const std::vector<index_t>& e_gs_ms_ns_strides,
965 AElementwiseOperation a_element_op,
966 BElementwiseOperation b_element_op,
967 CDEElementwiseOperation cde_element_op)
968 {
969 return Argument{p_a,
970 p_b,
971 p_ds,
972 p_e,
973 a_gs_ms_ks_lengths,
974 b_gs_ns_ks_lengths,
975 ds_gs_ms_ns_lengths,
976 e_gs_ms_ns_lengths,
977 a_gs_ms_ks_strides,
978 b_gs_ns_ks_strides,
979 ds_gs_ms_ns_strides,
980 e_gs_ms_ns_strides,
981 1,
982 1,
983 a_element_op,
984 b_element_op,
985 cde_element_op};
986 }
987
988 // polymorphic
989 std::unique_ptr<BaseArgument>
990 MakeArgumentPointer(const void* p_a,
991 const void* p_b,
992 std::array<const void*, NumDTensor> p_ds,
993 void* p_e,
994 const std::vector<index_t>& a_gs_ms_ks_lengths,
995 const std::vector<index_t>& a_gs_ms_ks_strides,
996 const std::vector<index_t>& b_gs_ns_ks_lengths,
997 const std::vector<index_t>& b_gs_ns_ks_strides,
998 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
999 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
1000 const std::vector<index_t>& e_gs_ms_ns_lengths,
1001 const std::vector<index_t>& e_gs_ms_ns_strides,
1002 AElementwiseOperation a_element_op,
1003 BElementwiseOperation b_element_op,
1004 CDEElementwiseOperation cde_element_op) override
1005 {
1006 return std::make_unique<Argument>(p_a,
1007 p_b,
1008 p_ds,
1009 p_e,
1010 a_gs_ms_ks_lengths,
1011 b_gs_ns_ks_lengths,
1012 ds_gs_ms_ns_lengths,
1013 e_gs_ms_ns_lengths,
1014 a_gs_ms_ks_strides,
1015 b_gs_ns_ks_strides,
1016 ds_gs_ms_ns_strides,
1017 e_gs_ms_ns_strides,
1018 1,
1019 1,
1020 a_element_op,
1021 b_element_op,
1022 cde_element_op);
1023 }
1024
1025 static auto MakeInvoker() { return Invoker{}; }
1026
1027 // polymorphic
1028 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1029 {
1030 return std::make_unique<Invoker>(Invoker{});
1031 }
1032
1033 // polymorphic
1034 std::string GetTypeString() const override
1035 {
1036 auto str = std::stringstream();
1037
1038 std::map<LoopScheduler, std::string> LoopSchedToString{
1039 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
1040
1041 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
1042 {PipelineVersion::v2, "v2"}};
1043
1044 // clang-format off
1045 str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle"
1046 << "<"
1047 << BlockSize << ", "
1048 << MPerBlock << ", "
1049 << NPerBlock << ", "
1050 << KPerBlock << ", "
1051 << K1 << ", "
1052 << MPerWmma << ", "
1053 << NPerWmma << ", "
1054 << MRepeat << ", "
1055 << NRepeat
1056 << ">"
1057 << " AEnableLds: "
1058 << AEnableLds << ", "
1059 << "BEnableLds: "
1060 << BEnableLds << ", "
1061 << "NumPrefetch: "
1062 << NumPrefetch << ", "
1063 << "LoopScheduler: "
1064 << LoopSchedToString[LoopSched] << ", "
1065 << "PipelineVersion: "
1066 << PipelineVersionToString[PipelineVer];
1067 // clang-format on
1068
1069 return str.str();
1070 }
1071};
1072
1073} // namespace device
1074} // namespace tensor_operation
1075} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
TensorSpecialization
Definition tensor_specialization.hpp:11
@ Packed
Definition tensor_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__global__ void kernel_contraction_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:133
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
Definition utility/sequence.hpp:43
Definition utility/sequence.hpp:256
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition device_base.hpp:197
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:510
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:520
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:515
ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, index_t batch_stride_B, DsGridDesc_G_M_N ds_grid_desc_g_m_n, EGridDesc_G_M_N e_grid_desc_g_m_n)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:499
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:532
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:606
index_t a_kz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:728
index_t e_nz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:733
AGridDesc a_grid_desc_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:701
index_t b_kz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:730
const ADataType * p_a_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:695
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:704
index_t a_mz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:727
std::array< index_t, NumDTensor > ds_nz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:731
EDataType * p_e_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:698
BGridDesc b_grid_desc_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:702
GridwiseOp::DsGridPointer p_ds_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:697
index_t a_batch_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:735
BElementwiseOperation b_element_op_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:722
index_t M01_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:717
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_strides, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:607
const BDataType * p_b_grid_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:696
index_t b_batch_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:736
index_t N01_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:718
AElementwiseOperation a_element_op_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:721
GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:711
DsGridDesc_G_M_N ds_grid_desc_g_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:705
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:739
index_t e_mz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:732
CDEElementwiseOperation cde_element_op_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:723
EGridDesc_G_M_N e_grid_desc_g_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:706
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:703
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:714
GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:709
index_t b_nz_stride_
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:729
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:749
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:752
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:823
DeviceOp::Argument Argument
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:750
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:119
decltype(MakeEGridDescriptor_G_M_N({}, {})) EGridDesc_G_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:495
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:990
static constexpr auto I6
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:129
static auto MakeBGridDescriptor(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:248
static constexpr auto I0
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:123
decltype(DeviceOp::MakeAGridDescriptor({}, {})) AGridDesc
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:545
static constexpr auto I2
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:125
decltype(MakeEGridDescriptor_M_N({}, {})) EGridDesc_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:492
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}))> DsGridDesc_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:491
static constexpr index_t NumDTensor
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:121
static constexpr auto I1
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:124
decltype(DeviceOp::MakeBGridDescriptor({}, {})) BGridDesc
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:546
static constexpr auto WmmaK
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:135
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:549
static constexpr auto BEnableLds_manu
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:147
static constexpr auto AEnableLds_auto
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:140
static auto MakeInvoker()
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:1025
static constexpr auto matrix_padder
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:152
static constexpr auto AEnableLds
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:149
static constexpr auto BEnableLds_auto
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:142
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:466
static constexpr auto K1Number
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:131
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:1028
DeviceBatchedContractionMultipleD_Wmma_CShuffle DeviceOp
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:120
std::string GetTypeString() const override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:1034
static auto MakeEGridDescriptor_G_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:396
static constexpr auto MWaves
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:133
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:340
remove_cvref_t< decltype(MakeDsGridDescriptor_G_M_N({}, {}))> DsGridDesc_G_M_N
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:494
static constexpr bool IsValidCompilationParameter()
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:830
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:947
static constexpr auto I5
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:128
static auto MakeDsGridDescriptor_G_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:478
static constexpr auto NWaves
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:134
static constexpr auto I4
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:127
static constexpr auto MaxVectorLoadA
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:137
static constexpr auto AEnableLds_manu
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:146
static constexpr auto BEnableLds
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:150
static auto MakeAGridDescriptor(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:156
static constexpr auto I3
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:126
static constexpr auto MaxVectorLoadB
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:138
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:953
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_contraction_multiple_d_wmma_cshuffle.hpp:836
Definition device_batched_contraction_multiple_d.hpp:39
Definition matrix_padder.hpp:180
#define CK_ENV(name)
Definition utility/env.hpp:129