device_moe_mx_gemm_bpreshuffle.hpp Source File

device_moe_mx_gemm_bpreshuffle.hpp Source File#

Composable Kernel: device_moe_mx_gemm_bpreshuffle.hpp Source File
device_moe_mx_gemm_bpreshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename ADataType,
29 typename AScaleDataType,
30 typename BDataType,
31 typename BScaleDataType,
32 typename DsDataType,
33 typename CDataType,
34 typename GemmAccDataType,
35 typename CShuffleDataType,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
39 GemmSpecialization GemmSpec,
40 index_t ScaleBlockSize,
41 index_t BlockSize,
42 index_t MPerBlock,
43 index_t NPerBlock,
44 index_t KPerBlock,
45 index_t AK1,
46 index_t BK1,
47 index_t MPerXDL,
48 index_t NPerXDL,
49 index_t MXdlPerWave,
50 index_t NXdlPerWave,
51 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52 typename ABlockTransferThreadClusterArrangeOrder,
53 typename ABlockTransferSrcAccessOrder,
54 index_t ABlockTransferSrcVectorDim,
55 index_t ABlockTransferSrcScalarPerVector,
56 index_t ABlockTransferDstScalarPerVector_AK1,
57 bool ABlockLdsExtraM,
58 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59 typename BBlockTransferThreadClusterArrangeOrder,
60 typename BBlockTransferSrcAccessOrder,
61 index_t BBlockTransferSrcVectorDim,
62 index_t BBlockTransferSrcScalarPerVector,
63 index_t BBlockTransferDstScalarPerVector_BK1,
64 bool BBlockLdsExtraN,
65 index_t CShuffleMXdlPerWavePerShuffle,
66 index_t CShuffleNXdlPerWavePerShuffle,
67 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68 typename CDEShuffleBlockTransferScalarPerVectors,
71 index_t ActivationOP = 0,
72 bool NSwizzle = false,
73 bool IsInputGemm = true,
74 bool MulRoutedWeight = true,
75 typename IndexType = index_t,
76 typename ComputeTypeA = ADataType,
77 typename ComputeTypeB = BDataType>
79 BLayout,
80 DsLayout,
81 CLayout,
82 ADataType,
83 AScaleDataType,
84 BDataType,
85 BScaleDataType,
86 DsDataType,
87 CDataType,
88 ScaleBlockSize,
89 AElementwiseOperation,
90 BElementwiseOperation,
91 CElementwiseOperation>
92{
// Values returned by GetNXdlPerWave<true>() and GetNXdlPerWave<false>().
// NOTE(review): the 64/32 suffixes presumably denote the wavefront size -- IsSupportedArgument
// takes the NXdlPerWave64 path when get_warp_size() == 64 and the NXdlPerWave32 path
// otherwise; confirm against the definition of GetNXdlPerWave.
94 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
95 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
// Number of D tensors, taken from DsDataType::Size().
96 static constexpr index_t NumDTensor = DsDataType::Size();
// Alias template that instantiates the gridwise GEMM for a given NXdlPerWave_ value, forwarding
// this device op's template parameters. Listing line 98 (the head of the alias) is elided in
// this rendering; the symbol index of the page names it
// `GridwiseGemmBase = GridwiseMoeGemmMX_BPreshuffle<...>`.
97 template <index_t NXdlPerWave_>
99 ALayout,
100 BLayout,
101 DsLayout,
102 CLayout,
103 ADataType,
104 AScaleDataType,
105 BDataType,
106 BScaleDataType,
107 GemmAccDataType,
108 CShuffleDataType,
109 DsDataType,
110 CDataType,
111 AElementwiseOperation,
112 BElementwiseOperation,
113 CElementwiseOperation,
114 GemmSpec,
115 ScaleBlockSize,
116 BlockSize,
117 MPerBlock,
118 NPerBlock,
119 KPerBlock,
120 AK1,
121 BK1,
122 MPerXDL,
123 NPerXDL,
124 MXdlPerWave,
    // The value handed to the gridwise GEMM is clamped to at least 2, whatever NXdlPerWave_ is.
125 math::max(2, NXdlPerWave_),
126 ABlockTransferThreadClusterLengths_AK0_M_AK1,
127 ABlockTransferThreadClusterArrangeOrder,
128 ABlockTransferSrcAccessOrder,
129 ABlockTransferSrcVectorDim,
130 ABlockTransferSrcScalarPerVector,
131 ABlockTransferDstScalarPerVector_AK1,
    // NOTE(review): hard-coded `false`; the meaning of this slot is defined by the gridwise
    // GEMM template, which is not visible here.
132 false,
133 ABlockLdsExtraM,
134 BBlockTransferThreadClusterLengths_BK0_N_BK1,
135 BBlockTransferThreadClusterArrangeOrder,
136 BBlockTransferSrcAccessOrder,
137 BBlockTransferSrcVectorDim,
138 BBlockTransferSrcScalarPerVector,
139 BBlockTransferDstScalarPerVector_BK1,
    // NOTE(review): hard-coded `false`, same remark as for the A-side slot above.
140 false,
141 BBlockLdsExtraN,
142 CShuffleMXdlPerWavePerShuffle,
143 CShuffleNXdlPerWavePerShuffle,
144 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
145 CDEShuffleBlockTransferScalarPerVectors,
146 BlkGemmPipeSched,
147 BlkGemmPipelineVer,
148 ActivationOP,
149 NSwizzle,
150 IsInputGemm,
151 MulRoutedWeight,
152 IndexType,
153 ComputeTypeA,
154 ComputeTypeB>;
157
// Host-side argument type of this device op: the Argument of the GridwiseGemm64 instantiation.
// The GridwiseGemm32 path in IsSupportedArgument reinterprets this same object as
// GridwiseGemm32::Argument.
158 using Argument = typename GridwiseGemm64::Argument;
159
162
// Returns the pre-shuffle parameter of this instance, which is the NPerXDL template argument.
163 int GetPreShuffleParameters() override { return NPerXDL; }
164
165 // Invoker
166 struct Invoker : public BaseInvoker
167 {
// Validates `arg`, picks the kernel instantiation matching the pipeline version / main-loop /
// tail-number combination, launches it and returns the time reported by the launch helper.
//
// Template parameter GridwiseGemm: gridwise GEMM instantiation whose Argument type is `arg`.
// arg:           kernel argument; printed when stream_config.log_level_ > 0.
// stream_config: launch/timing configuration (log level, flush_cache, rotating_count, stream).
// Returns:       ave_time as set by the launch helper; stays 0 if no kernel is launched.
// Throws:        std::runtime_error when GridwiseGemm::CheckValidity(arg) fails, or when a main
//                K loop exists and BlkGemmPipelineVer is neither v1 nor v3.
//
// NOTE: several lines of the original header are elided in this listing (228, 239, 271 and the
// last template argument of every kernel instantiation); they are pointed out where they occur.
168 template <typename GridwiseGemm>
169 float RunImp(const typename GridwiseGemm::Argument& arg,
170 const StreamConfig& stream_config = StreamConfig{})
171 {
172 if(stream_config.log_level_ > 0)
173 {
174 arg.Print();
175 }
176
177 if(!GridwiseGemm::CheckValidity(arg))
178 {
179 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
180 }
181
182 index_t gdx, gdy, gdz;
183 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
184
185 float ave_time = 0;
186
    // K extent handled per split: ceil(K / (KBatch * KPerBlock)) * KPerBlock.
187 index_t k_grain = arg.KBatch * KPerBlock;
188 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
189
190 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
191
    // Launch helper shared by all kernel variants below; writes the measured time to ave_time
    // (captured by reference).
192 const auto RunKernel = [&](const auto& kernel) {
193 if(stream_config.flush_cache)
194 {
195
196 std::array<std::size_t, NumDTensor> DsSize;
197
    // Mutable copy of the argument; it is handed to the rotating-memory wrapper and to the
    // launch call instead of `arg`.
198 auto arg_ = arg;
199
200 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
201 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
202 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
203 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
204
    // Buffer sizes in bytes: descriptor element-space size times sizeof(element type).
205 auto size_a_buffer =
206 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
207 auto size_b_buffer =
208 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
209
210 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
211 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
212
213 static_for<0, NumDTensor, 1>{}([&](auto i) {
214 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
215 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
216 });
217 ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
218 DsDataType>
219 rotating_mem(arg_,
220 stream_config.rotating_count,
221 size_a_buffer,
222 size_b_buffer,
223 DsSize);
224 rotating_mem.Print();
225
    // Pre-launch callback handed to the launch call below.
226 auto run_flush_cache = [&]() {
227 // flush icache
    // (listing line 228 is elided here; presumably the flush_icache() call listed in the
    // symbol index of this page -- confirm in the real header)
229 // rotating mem
230 rotating_mem.Next();
231 // clear c mem
    // NOTE(review): the hipMemsetAsync status is only turned into a string that is then
    // discarded, so a failed memset goes unnoticed. Also, if M and N are index_t (int32),
    // M * N is evaluated in 32-bit before being widened by sizeof -- confirm it cannot overflow.
232 if(arg_.KBatch > 1)
233 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
234 0,
235 arg_.M * arg_.N * sizeof(CDataType),
236 stream_config.stream_id_));
237 };
238
    // (listing line 239 is elided; presumably `ave_time = ...launch_and_time_kernel_with_preprocess(`,
    // whose remaining arguments follow -- confirm in the real header)
240 stream_config,
241 run_flush_cache,
242 kernel,
243 dim3(gdx, gdy, gdz),
244 dim3(BlockSize),
245 0,
246 arg_);
247 }
248 else
249 {
    // Same C-buffer clearing as in the flush-cache path; the same review remarks apply.
250 if(arg.KBatch > 1)
251 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
252 0,
253 arg.M * arg.N * sizeof(CDataType),
254 stream_config.stream_id_));
255
256 ave_time = launch_and_time_kernel(
257 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
258 }
259 };
260
261 // TODO: Check if this is the right algorithm for minimum_occupancy
    // Intrawave scheduler: 2 when the pipeline is v3 and MPerBlock*NPerBlock*KPerBlock*sizeof(ADataType)
    // is at most 128*128*64*2, otherwise 1. Any other scheduler: 2.
262 constexpr index_t minimum_occupancy =
263 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
264 ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
265 MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
266 ? 2
267 : 1
268 : 2;
269
    // (the initializer on listing line 271 is elided; presumably it selects between the Set and
    // AtomicAdd memory operations listed in the symbol index -- confirm in the real header)
270 constexpr auto MemoryDataOp =
272
    // Kernel dispatch. The second template argument mirrors has_main_k_block_loop; v1 uses
    // kernel_moe_mxgemm and v3 uses kernel_moe_mxgemm_2lds. In every instantiation the last
    // template argument is elided in this listing (presumably the TailNumber matching the
    // Odd / else branch), which is why the paired branches look identical here.
273 if(has_main_k_block_loop)
274 {
275 // Tail number always full
276 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
277 {
278 {
279 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
280 {
281 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
282 true,
283 MemoryDataOp,
284 minimum_occupancy,
286 RunKernel(kernel);
287 }
288 else
289 {
290 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
291 true,
292 MemoryDataOp,
293 minimum_occupancy,
295 RunKernel(kernel);
296 }
297 }
298 }
299 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
300 {
301 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
302 {
303 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
304 true,
305 MemoryDataOp,
306 minimum_occupancy,
308 RunKernel(kernel);
309 }
310 else
311 {
312 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
313 true,
314 MemoryDataOp,
315 minimum_occupancy,
317 RunKernel(kernel);
318 }
319 }
320 else
321 {
322 throw std::runtime_error("todo: only v1 & v3 support now");
323 }
324 }
325 else
326 {
    // NOTE(review): unlike the main-loop branch above, there is no trailing `else` here, so for
    // a pipeline version other than v1/v3 nothing is launched and 0 is returned silently.
327 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
328 {
329 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
330 {
331 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
332 false,
333 MemoryDataOp,
334 minimum_occupancy,
336 RunKernel(kernel);
337 }
338 else
339 {
340 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
341 false,
342 MemoryDataOp,
343 minimum_occupancy,
345 RunKernel(kernel);
346 }
347 }
348 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
349 {
350 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
351 {
352 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
353 false,
354 MemoryDataOp,
355 minimum_occupancy,
357 RunKernel(kernel);
358 }
359 else
360 {
361 const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
362 false,
363 MemoryDataOp,
364 minimum_occupancy,
366 RunKernel(kernel);
367 }
368 }
369 }
370
371 return ave_time;
372 }
373
375
376 // polymorphic
377 float Run(const BaseArgument* p_arg,
378 const StreamConfig& stream_config = StreamConfig{}) override
379 {
380 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
381 }
382 };
383
// Compile-time validity hook for this instantiation. Currently a stub: it reports every
// template-parameter combination as valid (see the TODO below).
384 static constexpr bool IsValidCompilationParameter()
385 {
386 // TODO: properly implement this check
387 return true;
388 }
389
// Host-side check whether `arg` can be run by this instance. Returns false as soon as one of
// the conditions below fails, and also when the `if constexpr` guard of the selected warp-size
// branch is not satisfied.
// NOTE: listing lines 397, 422 and 429 are elided in this rendering; see the remarks in place.
390 static bool IsSupportedArgument(const Argument& arg)
391 {
392 // only impl kbatch 1 now
393 if(arg.KBatch > 1)
394 {
395 return false;
396 }
    // (the condition on listing line 397 is elided; presumably a device-capability check such
    // as the is_xdl_wmma_supported() listed in the symbol index -- confirm in the real header)
398 {
399 return false;
400 }
    // NOTE(review): KBatch > 1 has already returned false above, so the `arg.KBatch > 1` term
    // makes this condition always false at this point.
401 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
402 {
403 return false;
404 }
405
    // K must be a multiple of both AK1 and BK1 unless GemmSpec is one of the K-padding
    // specializations (MKPadding, NKPadding, MNKPadding, KPadding).
406 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
407 GemmSpec == GemmSpecialization::NKPadding ||
408 GemmSpec == GemmSpecialization::MNKPadding ||
409 GemmSpec == GemmSpecialization::KPadding))
410 {
411 return false;
412 }
    // Independently of GemmSpec, N must be a multiple of NPerBlock and K of KPerBlock.
413 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
414 {
415 return false;
416 }
417
    // Final decision is delegated per warp size. The statements on listing lines 422 and 429
    // are elided; presumably they return the CheckValidity result of GridwiseGemm64 and
    // GridwiseGemm32 respectively (the latter on `arg` reinterpreted as GridwiseGemm32::Argument,
    // as the surviving line 430 shows) -- confirm in the real header.
418 if(get_warp_size() == 64)
419 {
420 if constexpr(NXdlPerWave64 > 0)
421 {
423 }
424 }
425 else
426 {
427 if constexpr(NXdlPerWave32 > 0)
428 {
430 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
431 }
432 }
433 return false;
434 }
435
436 // polymorphic
437 bool IsSupportedArgument(const BaseArgument* p_arg) override
438 {
439 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
440 }
441
// Builds an Argument by value from type-erased pointers and problem sizes.
// The three leading pointers (sorted token ids, sorted expert ids, max token id) are cast to
// `const index_t*`; A/B data and scale pointers are cast to their respective element types and
// p_c to `CDataType*`. All remaining parameters are forwarded unchanged, in the same order.
442 static auto MakeArgument(const void* p_sorted_token_ids,
443 const void* p_sorted_expert_ids,
444 const void* p_max_token_id,
445 const void* p_a,
446 const void* p_a_scale,
447 const void* p_b,
448 const void* p_b_scale,
449 std::array<const void*, NumDTensor> p_ds,
450 void* p_c,
451 index_t NumTokens,
452 index_t TopK,
453 index_t M,
454 index_t N,
455 index_t K,
456 index_t StrideA,
457 index_t StrideScaleA,
458 index_t StrideB,
459 index_t StrideScaleB,
460 std::array<index_t, NumDTensor> StrideDs,
461 index_t StrideC,
462 index_t KBatch,
463 AElementwiseOperation a_element_op,
464 BElementwiseOperation b_element_op,
465 CElementwiseOperation c_element_op)
466 {
467 return Argument{static_cast<const index_t*>(p_sorted_token_ids),
468 static_cast<const index_t*>(p_sorted_expert_ids),
469 static_cast<const index_t*>(p_max_token_id),
470 static_cast<const ADataType*>(p_a),
471 static_cast<const AScaleDataType*>(p_a_scale),
472 static_cast<const BDataType*>(p_b),
473 static_cast<const BScaleDataType*>(p_b_scale),
474 p_ds,
475 static_cast<CDataType*>(p_c),
476 NumTokens,
477 TopK,
478 M,
479 N,
480 K,
481 StrideA,
482 StrideScaleA,
483 StrideB,
484 StrideScaleB,
485 StrideDs,
486 StrideC,
487 KBatch,
488 a_element_op,
489 b_element_op,
490 c_element_op};
491 }
492
// Returns a value-initialized Invoker by value.
493 static auto MakeInvoker() { return Invoker{}; }
494
495 // polymorphic
// Type-erased factory: builds a heap-allocated Argument and returns it as unique_ptr<BaseArgument>.
// This overload takes no sorted-token / expert-id / max-token-id pointers, NumTokens or TopK;
// compared with MakeArgument's argument order, those slots are filled with nullptr (three
// times), M and 0 respectively. Data pointers are cast to their element types exactly as in
// MakeArgument; the remaining parameters are forwarded unchanged.
496 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
497 const void* p_a_scale,
498 const void* p_b,
499 const void* p_b_scale,
500 std::array<const void*, NumDTensor> p_ds,
501 void* p_c,
502 index_t M,
503 index_t N,
504 index_t K,
505 index_t StrideA,
506 index_t StrideScaleA,
507 index_t StrideB,
508 index_t StrideScaleB,
509 std::array<ck::index_t, NumDTensor> StrideDs,
510 index_t StrideC,
511 index_t KBatch,
512 AElementwiseOperation a_element_op,
513 BElementwiseOperation b_element_op,
514 CElementwiseOperation c_element_op) override
515 {
516 return std::make_unique<Argument>(nullptr,
517 nullptr,
518 nullptr,
519 static_cast<const ADataType*>(p_a),
520 static_cast<const AScaleDataType*>(p_a_scale),
521 static_cast<const BDataType*>(p_b),
522 static_cast<const BScaleDataType*>(p_b_scale),
523 p_ds,
524 static_cast<CDataType*>(p_c),
525 M, // NumTokens slot: arbitrary placeholder value, noted by the author as unused
    // TopK slot (by position, matching MakeArgument): fixed to 0.
526 0,
527 M,
528 N,
529 K,
530 StrideA,
531 StrideScaleA,
532 StrideB,
533 StrideScaleB,
534 StrideDs,
535 StrideC,
536 KBatch,
537 a_element_op,
538 b_element_op,
539 c_element_op);
540 }
541
542 // polymorphic
// Type-erased factory: returns a heap-allocated Invoker as unique_ptr<BaseInvoker>.
// NOTE(review): the Invoker{} temporary is redundant -- std::make_unique<Invoker>() would
// value-initialize the object directly.
543 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
544 {
545 return std::make_unique<Invoker>(Invoker{});
546 }
547
548 // polymorphic
// Builds a human-readable description of this instance: GEMM specialization, first letter of
// the A/B/C layout names, block size, block tile, wave tile, wave map, global-read vector
// widths, pipeline scheduler/version names and GridwiseGemm64's PrefetchStages.
// NOTE: the initializer lists of both lookup maps (listing lines 554-555 and 558-562) are
// elided in this rendering. Both maps are rebuilt on every call.
549 std::string GetTypeString() const override
550 {
551 auto str = std::stringstream();
552
    // enum -> display-name table for the pipeline scheduler (entries elided in this listing)
553 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
556
    // enum -> display-name table for the pipeline version (entries elided in this listing)
557 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
563
564 // clang-format off
565 str << "DeviceMoeGEmmMx"
566 << "<"
567 << getGemmSpecializationString(GemmSpec) << ", "
568 << std::string(ALayout::name)[0]
569 << std::string(BLayout::name)[0]
570 << std::string(CLayout::name)[0]
571 << ">"
572 << " BlkSize: "
573 << BlockSize << ", "
574 << "BlkTile: "
575 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
576 << "WaveTile: "
577 << MPerXDL<<"x"<<NPerXDL << ", "
578 << "WaveMap: "
    // prints the NXdlPerWave template argument as given, not the clamped value used by the
    // gridwise GEMM instantiations
579 << MXdlPerWave<<"x" << NXdlPerWave<<", "
580 << "VmemReadVec: "
581 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
582 << "BlkGemmPipelineScheduler: "
    // operator[] on the local (non-const) maps: an enum value missing from a table would be
    // inserted with an empty string rather than reported
583 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
584 << "BlkGemmPipelineVersion: "
585 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
586 << "BlkGemmPipelinePrefetchStages: "
587 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
588 // clang-format on
589
590 return str.str();
591 }
592};
593
594} // namespace device
595} // namespace tensor_operation
596} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm_bns.hpp:48
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm.hpp:90
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
constexpr index_t packed_size_v
Definition data_type.hpp:411
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:174
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d.hpp:167
Definition device_moe_mx_gemm_bpreshuffle.hpp:167
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_moe_mx_gemm_bpreshuffle.hpp:169
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_moe_mx_gemm_bpreshuffle.hpp:377
Definition device_moe_mx_gemm_bpreshuffle.hpp:92
typename GridwiseGemm64::Argument Argument
Definition device_moe_mx_gemm_bpreshuffle.hpp:158
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_moe_mx_gemm_bpreshuffle.hpp:94
GridwiseMoeGemmMX_BPreshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, math::max(2, NXdlPerWave_), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_moe_mx_gemm_bpreshuffle.hpp:98
static constexpr index_t APackedSize
Definition device_moe_mx_gemm_bpreshuffle.hpp:160
static constexpr index_t NumDTensor
Definition device_moe_mx_gemm_bpreshuffle.hpp:96
static constexpr bool IsValidCompilationParameter()
Definition device_moe_mx_gemm_bpreshuffle.hpp:384
static constexpr index_t BPackedSize
Definition device_moe_mx_gemm_bpreshuffle.hpp:161
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_moe_mx_gemm_bpreshuffle.hpp:543
static bool IsSupportedArgument(const Argument &arg)
Definition device_moe_mx_gemm_bpreshuffle.hpp:390
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_moe_mx_gemm_bpreshuffle.hpp:156
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_moe_mx_gemm_bpreshuffle.hpp:437
int GetPreShuffleParameters() override
Definition device_moe_mx_gemm_bpreshuffle.hpp:163
std::string GetTypeString() const override
Definition device_moe_mx_gemm_bpreshuffle.hpp:549
static auto MakeInvoker()
Definition device_moe_mx_gemm_bpreshuffle.hpp:493
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_moe_mx_gemm_bpreshuffle.hpp:155
static constexpr auto NXdlPerWave32
Definition device_moe_mx_gemm_bpreshuffle.hpp:95
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_moe_mx_gemm_bpreshuffle.hpp:442
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_moe_mx_gemm_bpreshuffle.hpp:496
Definition flush_cache.hpp:174