batched_contraction_kernel.hpp Source File

batched_contraction_kernel.hpp Source File

Composable Kernel: batched_contraction_kernel.hpp Source File
batched_contraction_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
81
82namespace ck_tile {
83
// Host-side arguments for a batched tensor contraction: device pointers for A, B, the
// auxiliary D tensors and the output E, the split-K factor, and one dimension vector and one
// stride vector per tensor. The constructor only copies its arguments; rank and extent
// consistency are validated later when kernel arguments are built.
98template <ck_tile::index_t NumDTensor = 0>
100{
118 const void* a_ptr_,
119 const void* b_ptr_,
120 const std::array<const void*, NumDTensor>& ds_ptr_,
121 void* e_ptr_,
122 ck_tile::index_t k_batch_, // number of split-K slices
123 const std::vector<ck_tile::index_t>& A_dims_, // [G0, G1, ..., M0, M1, ... , K0, K1, ...]
124 const std::vector<ck_tile::index_t>& B_dims_, // [G0, G1, ..., N0, N1, ... , K0, K1, ...]
125 const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
126 Ds_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...][NumDTensor]
127 const std::vector<ck_tile::index_t>& E_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...]
128
129 const std::vector<ck_tile::index_t>& A_strides_, // [G0, G1, ..., M0, M1, ..., K0, K1, ...]
130 const std::vector<ck_tile::index_t>& B_strides_, // [G0, G1, ..., N0, N1, ..., K0, K1, ...]
131 const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
132 Ds_strides_, // [G0, G1, ..., M0, M1, ..., N0, N1, ...][NumDTensor]
133 const std::vector<ck_tile::index_t>&
134 E_strides_) // [G0, G1, ..., M0, M1, ..., N0, N1, ...]
135
136 : a_ptr(a_ptr_),
137 b_ptr(b_ptr_),
138 ds_ptr(ds_ptr_),
139 e_ptr(e_ptr_),
140 k_batch(k_batch_),
141 A_dims(A_dims_),
142 B_dims(B_dims_),
143 Ds_dims(Ds_dims_),
144 E_dims(E_dims_),
145 A_strides(A_strides_),
146 B_strides(B_strides_),
147 Ds_strides(Ds_strides_),
148 E_strides(E_strides_)
149 {
150 }
151
152 const void* a_ptr;
153 const void* b_ptr;
154 std::array<const void*, NumDTensor> ds_ptr;
155 void* e_ptr;
157 const std::vector<ck_tile::index_t>
159 const std::vector<ck_tile::index_t>
161 const std::array<std::vector<ck_tile::index_t>, NumDTensor>
163 const std::vector<ck_tile::index_t>
165 const std::vector<ck_tile::index_t>
167 const std::vector<ck_tile::index_t>
169 const std::array<std::vector<ck_tile::index_t>, NumDTensor>
171 const std::vector<ck_tile::index_t>
173};
174
182
183template <ck_tile::index_t NumDimG,
184 ck_tile::index_t NumDimM,
185 ck_tile::index_t NumDimN,
186 ck_tile::index_t NumDimK,
187 ck_tile::index_t NumDTensor = 0>
219
232
233template <typename Problem_,
234 typename TilePartitioner_,
235 typename GemmPipeline_,
236 typename EpiloguePipeline_>
238{
239 // Type aliases for cleaner code and better readability
241 using ADataType =
243 using BDataType =
248 using EDataType =
250
251 // Compile-time dimension constants extracted from problem specification
252 static constexpr ck_tile::index_t NumDimG = Problem::NumDimG;
253 static constexpr ck_tile::index_t NumDimM =
254 Problem::NumDimM;
255 static constexpr ck_tile::index_t NumDimN =
256 Problem::NumDimN;
257 static constexpr ck_tile::index_t NumDimK =
258 Problem::NumDimK;
259 static constexpr ck_tile::index_t NumDTensor =
260 Problem::NumDTensor;
261
262 // Pipeline and partitioning strategy types
269
270 // Underlying GEMM kernel that performs the actual computation
273
274 static constexpr ck_tile::index_t kBlockSize =
276
281
284 CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; }
285
290 CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs)
291 {
292 typename UniversalGemmKernel::KernelArgs gemm_kargs{{kargs.a_ptr},
293 {kargs.b_ptr},
294 kargs.ds_ptr,
295 kargs.e_ptr,
296 kargs.M_total,
297 kargs.N_total,
298 kargs.K_total,
299 {kargs.stride_A},
300 {kargs.stride_B},
301 kargs.stride_Ds,
302 kargs.stride_E,
303 kargs.k_batch};
304
305 return UniversalGemmKernel::IsSupportedArgument(gemm_kargs) && kargs.G_total > 0;
306 }
307
312 {
314 }
315
318 CK_TILE_HOST static constexpr auto GetBlockSize()
319 {
321 }
322
323 CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
324 {
325 return dim3(
326 TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
327 }
328
329 CK_TILE_HOST static constexpr KernelArgs
331 {
332 const auto expected_A_dims = NumDimG + NumDimM + NumDimK;
333 const auto expected_B_dims = NumDimG + NumDimN + NumDimK;
334 const auto expected_E_dims = NumDimG + NumDimM + NumDimN;
335
336 if(host_args.A_dims.size() != expected_A_dims ||
337 host_args.A_strides.size() != expected_A_dims)
338 {
339 throw std::invalid_argument("A dimension size mismatch");
340 }
341 if(host_args.B_dims.size() != expected_B_dims ||
342 host_args.B_strides.size() != expected_B_dims)
343 {
344 throw std::invalid_argument("B dimension size mismatch");
345 }
346 if(host_args.E_dims.size() != expected_E_dims ||
347 host_args.E_strides.size() != expected_E_dims)
348 {
349 throw std::invalid_argument("E dimension size mismatch");
350 }
351
352 for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
353 {
354 if(host_args.Ds_dims[d].size() != expected_E_dims ||
355 host_args.Ds_strides[d].size() != expected_E_dims)
356 {
357 throw std::invalid_argument("D dimension size mismatch");
358 }
359 }
360
361 KernelArgs kargs;
362 kargs.a_ptr = host_args.a_ptr;
363 kargs.b_ptr = host_args.b_ptr;
364 kargs.ds_ptr = host_args.ds_ptr;
365 kargs.e_ptr = host_args.e_ptr;
366 kargs.k_batch = host_args.k_batch;
367
368 // Validate and set G dimensions (must be identical across all tensors)
369 for(ck_tile::index_t i = 0; i < NumDimG; ++i)
370 {
371 // All tensors must have same G dimensions for valid contraction
372 if(host_args.A_dims[i] != host_args.B_dims[i] ||
373 host_args.A_dims[i] != host_args.E_dims[i])
374 {
375 throw std::invalid_argument(
376 "All tensors must have identical G dimensions for valid contraction");
377 }
378
379 // Store G dimensions (same for all tensors)
380 kargs.G_dims[i] = host_args.A_dims[i];
381 }
382
383 // Set batch strides from the stride of last G dimension
384 kargs.batch_stride_A = host_args.A_strides[NumDimG - 1];
385 kargs.batch_stride_B = host_args.B_strides[NumDimG - 1];
386 kargs.batch_stride_E = host_args.E_strides[NumDimG - 1];
387
388 for(ck_tile::index_t i = 0; i < NumDimM; ++i)
389 {
390 kargs.M_dims[i] = host_args.A_dims[NumDimG + i];
391 if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i])
392 {
393 throw std::invalid_argument("M dimension mismatch between A and E tensors");
394 }
395 }
396 for(ck_tile::index_t i = 0; i < NumDimN; ++i)
397 {
398 kargs.N_dims[i] = host_args.B_dims[NumDimG + i];
399 if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i])
400 {
401 throw std::invalid_argument("N dimension mismatch between B and E tensors");
402 }
403 }
404 for(ck_tile::index_t i = 0; i < NumDimK; ++i)
405 {
406 kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i];
407 if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i])
408 {
409 throw std::invalid_argument("K dimension mismatch between A and B tensors");
410 }
411 }
412
413 // Calculate total dimensions from individual dimension arrays
414 kargs.G_total = 1;
415 for(ck_tile::index_t i = 0; i < NumDimG; ++i)
416 {
417 kargs.G_total *= kargs.G_dims[i];
418 }
419
420 kargs.M_total = 1;
421 for(ck_tile::index_t i = 0; i < NumDimM; ++i)
422 {
423 kargs.M_total *= kargs.M_dims[i];
424 }
425
426 kargs.N_total = 1;
427 for(ck_tile::index_t i = 0; i < NumDimN; ++i)
428 {
429 kargs.N_total *= kargs.N_dims[i];
430 }
431
432 kargs.K_total = 1;
433 for(ck_tile::index_t i = 0; i < NumDimK; ++i)
434 {
435 kargs.K_total *= kargs.K_dims[i];
436 }
437
438 kargs.stride_A = kargs.K_total;
439 kargs.stride_B = kargs.K_total;
440 kargs.stride_E = kargs.N_total;
441
442 // Validate D tensors have same G dimensions and set their batch strides
443 for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
444 {
445 for(ck_tile::index_t i = 0; i < NumDimG; ++i)
446 {
447 if(host_args.Ds_dims[d][i] != host_args.A_dims[i])
448 {
449 throw std::invalid_argument(
450 "D tensor G dimensions must match A/B/E tensor G dimensions");
451 }
452 }
453 // Set batch stride for D tensor
454 kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1];
455 kargs.stride_Ds[d] = kargs.N_total; // D tensors same shape as E
456 }
457
458 return kargs;
459 }
460
461 CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const
462 {
463
464 const auto [iM, iN] =
465 TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x);
466 const ck_tile::index_t i_m =
467 __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
468 const ck_tile::index_t i_n =
469 __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
470
471 const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
472 const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
473
474 // Calculate batch offsets for each tensor
475 const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
476 const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B;
477 const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E;
478
479 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
480 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
481 EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr) + batch_offset_E;
482
483 std::array<const void*, NumDTensor> ds_batch_ptr;
484 static_for<0, NumDTensor, 1>{}([&](auto i) {
485 using DDataType = typename std::tuple_element<i.value, DsDataType>::type;
486 const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i];
487 ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
488 });
489
490 typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
491 {b_ptr},
492 ds_batch_ptr,
493 e_ptr,
494 kargs.M_total,
495 kargs.N_total,
496 kargs.K_total,
497 {kargs.stride_A},
498 {kargs.stride_B},
499 kargs.stride_Ds,
500 kargs.stride_E,
501 kargs.k_batch};
502
503 const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
504 i_splitk);
505
506 const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
507 const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
508 __shared__ char smem_ptr[GetSmemSize()];
509
510 UniversalGemmKernel::RunGemm({a_ptr_final},
511 {b_ptr_final},
512 ds_batch_ptr,
513 e_ptr,
514 smem_ptr,
515 gemm_kargs,
516 splitk_batch_offset,
517 i_m,
518 i_n);
519 }
520};
521
522} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
int32_t index_t
Definition integer.hpp:9
Definition batched_contraction_kernel.hpp:100
const std::array< std::vector< ck_tile::index_t >, NumDTensor > Ds_dims
Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:162
const void * b_ptr
Pointer to input tensor B.
Definition batched_contraction_kernel.hpp:153
const std::vector< ck_tile::index_t > E_strides
Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:172
const std::vector< ck_tile::index_t > E_dims
Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:164
std::array< const void *, NumDTensor > ds_ptr
Array of pointers to auxiliary input tensors D.
Definition batched_contraction_kernel.hpp:154
const std::vector< ck_tile::index_t > B_dims
Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:160
void * e_ptr
Pointer to output tensor E.
Definition batched_contraction_kernel.hpp:155
const std::vector< ck_tile::index_t > A_dims
Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:158
const std::array< std::vector< ck_tile::index_t >, NumDTensor > Ds_strides
Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:170
const std::vector< ck_tile::index_t > A_strides
Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:166
ck_tile::index_t k_batch
Number of k-splits for split-K batching.
Definition batched_contraction_kernel.hpp:156
CK_TILE_HOST BatchedContractionHostArgs(const void *a_ptr_, const void *b_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, ck_tile::index_t k_batch_, const std::vector< ck_tile::index_t > &A_dims_, const std::vector< ck_tile::index_t > &B_dims_, const std::array< std::vector< ck_tile::index_t >, NumDTensor > &Ds_dims_, const std::vector< ck_tile::index_t > &E_dims_, const std::vector< ck_tile::index_t > &A_strides_, const std::vector< ck_tile::index_t > &B_strides_, const std::array< std::vector< ck_tile::index_t >, NumDTensor > &Ds_strides_, const std::vector< ck_tile::index_t > &E_strides_)
Constructor for batched contraction host arguments.
Definition batched_contraction_kernel.hpp:117
const void * a_ptr
Pointer to input tensor A.
Definition batched_contraction_kernel.hpp:152
const std::vector< ck_tile::index_t > B_strides
Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:168
Kernel arguments for batched tensor contraction operations.
Definition batched_contraction_kernel.hpp:189
const void * b_ptr
Definition batched_contraction_kernel.hpp:191
std::array< ck_tile::index_t, NumDTensor > stride_Ds
Definition batched_contraction_kernel.hpp:216
const void * a_ptr
Definition batched_contraction_kernel.hpp:190
ck_tile::index_t stride_A
Definition batched_contraction_kernel.hpp:213
ck_tile::index_t M_total
Definition batched_contraction_kernel.hpp:209
ck_tile::index_t G_total
Definition batched_contraction_kernel.hpp:208
ck_tile::index_t stride_E
Definition batched_contraction_kernel.hpp:217
std::array< ck_tile::index_t, NumDTensor > batch_stride_Ds
Definition batched_contraction_kernel.hpp:206
std::array< const void *, NumDTensor > ds_ptr
Definition batched_contraction_kernel.hpp:192
ck_tile::index_t M_dims[NumDimM]
Definition batched_contraction_kernel.hpp:196
ck_tile::index_t K_dims[NumDimK]
Definition batched_contraction_kernel.hpp:198
ck_tile::index_t stride_B
Definition batched_contraction_kernel.hpp:214
ck_tile::index_t G_dims[NumDimG]
Definition batched_contraction_kernel.hpp:200
ck_tile::index_t N_dims[NumDimN]
Definition batched_contraction_kernel.hpp:197
ck_tile::index_t batch_stride_E
Definition batched_contraction_kernel.hpp:205
ck_tile::index_t K_total
Definition batched_contraction_kernel.hpp:211
ck_tile::index_t batch_stride_A
Definition batched_contraction_kernel.hpp:203
ck_tile::index_t k_batch
Definition batched_contraction_kernel.hpp:194
ck_tile::index_t batch_stride_B
Definition batched_contraction_kernel.hpp:204
void * e_ptr
Definition batched_contraction_kernel.hpp:193
ck_tile::index_t N_total
Definition batched_contraction_kernel.hpp:210
GPU kernel for batched tensor contraction operations.
Definition batched_contraction_kernel.hpp:238
static CK_TILE_HOST constexpr auto GetBlockSize()
Returns the GPU block size for kernel launch.
Definition batched_contraction_kernel.hpp:318
static constexpr ck_tile::index_t NumDTensor
Number of auxiliary input D tensors.
Definition batched_contraction_kernel.hpp:259
static CK_TILE_HOST constexpr KernelArgs MakeKernelArgs(const BatchedContractionHostArgs< NumDTensor > &host_args)
Definition batched_contraction_kernel.hpp:330
static constexpr ck_tile::index_t NumDimM
Number of M (output row) dimensions.
Definition batched_contraction_kernel.hpp:253
ck_tile::UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Definition batched_contraction_kernel.hpp:271
static CK_TILE_HOST constexpr bool IsSupportedArguments(const KernelArgs &kargs)
Validates whether the given kernel arguments are supported.
Definition batched_contraction_kernel.hpp:290
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Epilogue pipeline for post-GEMM operations.
Definition batched_contraction_kernel.hpp:267
ck_tile::remove_cvref_t< Problem_ > Problem
Tensor contraction problem specification.
Definition batched_contraction_kernel.hpp:240
static CK_TILE_HOST constexpr auto GridSize(const KernelArgs &kargs)
Definition batched_contraction_kernel.hpp:323
static CK_TILE_HOST constexpr auto GetKernelName()
Returns the kernel name for debugging and profiling purposes.
Definition batched_contraction_kernel.hpp:284
static constexpr ck_tile::index_t NumDimG
Number of batch dimensions.
Definition batched_contraction_kernel.hpp:252
ck_tile::remove_cvref_t< GemmPipeline_ > GemmPipeline
GEMM computation pipeline.
Definition batched_contraction_kernel.hpp:266
CK_TILE_DEVICE void operator()(const KernelArgs &kargs) const
Definition batched_contraction_kernel.hpp:461
ck_tile::remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition batched_contraction_kernel.hpp:263
ck_tile::remove_cvref_t< typename Problem::ADataType > ADataType
Data type for input tensor A.
Definition batched_contraction_kernel.hpp:241
static CK_TILE_HOST constexpr ck_tile::index_t GetSmemSize()
Returns the shared memory size required by the kernel.
Definition batched_contraction_kernel.hpp:311
ck_tile::remove_cvref_t< typename Problem::EDataType > EDataType
Data type for output tensor E.
Definition batched_contraction_kernel.hpp:248
static constexpr ck_tile::index_t NumDimN
Number of N (output column) dimensions.
Definition batched_contraction_kernel.hpp:255
BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor > KernelArgs
Definition batched_contraction_kernel.hpp:277
static constexpr ck_tile::index_t NumDimK
Number of K (contraction) dimensions.
Definition batched_contraction_kernel.hpp:257
ck_tile::remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition batched_contraction_kernel.hpp:245
ck_tile::remove_cvref_t< typename Problem::BDataType > BDataType
Data type for input tensor B.
Definition batched_contraction_kernel.hpp:243
static constexpr ck_tile::index_t kBlockSize
GPU block size inherited from GEMM kernel.
Definition batched_contraction_kernel.hpp:274
Definition universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition universal_gemm_kernel.hpp:368
std::array< index_t, NumBTensor > bs_k_split_offset
Definition universal_gemm_kernel.hpp:369
The Universal GEMM kernel template.
Definition universal_gemm_kernel.hpp:154
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition universal_gemm_kernel.hpp:955
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition universal_gemm_kernel.hpp:319
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition universal_gemm_kernel.hpp:257