blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp Source File

blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp Source File
blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10template <BlockGemmPipelineVersion BlkGemmPipelineVer,
11 BlockGemmPipelineScheduler BlkGemmPipeSche,
12 index_t ThreadBlockSize,
13 index_t ScaleBlockSize,
14 typename ADataType,
15 typename AScaleDataType,
16 typename BDataType,
17 typename BScaleDataType,
18 typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
19 // must be used for compute
20 typename AccDataType,
21 typename ATileDesc,
22 typename BTileDesc,
23 typename AMmaTileDesc,
24 typename BMmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
27 index_t MPerBlock,
28 index_t NPerBlock,
29 index_t KPerBlock,
30 index_t MPerXDL,
31 index_t NPerXDL,
32 index_t MRepeat,
33 index_t NRepeat,
34 index_t KPack,
35 bool GUFusion = false>
37{
38
39 // Hardware MX GEMM pipeline
40 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
41 {
42 if constexpr(GUFusion)
43 {
44 return nullptr;
45 }
46 else
47 {
48 return nullptr;
49 }
50 }
51 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
52 {
53 if constexpr(GUFusion)
54 {
56 BlkGemmPipeSche,
57 ThreadBlockSize,
58 ScaleBlockSize,
59 ADataType,
60 AScaleDataType,
61 BDataType,
62 BScaleDataType,
63 ATileDesc,
64 BTileDesc,
65 AMmaTileDesc,
66 BMmaTileDesc,
67 ABlockTransferSrcScalarPerVector,
68 BBlockTransferSrcScalarPerVector,
69 MPerBlock,
70 NPerBlock,
71 KPerBlock,
72 MPerXDL,
73 NPerXDL,
74 MRepeat,
75 NRepeat,
76 KPack>{};
77 }
78 else
79 {
80 return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlkGemmPipeSche,
81 ThreadBlockSize,
82 ScaleBlockSize,
83 ADataType,
84 AScaleDataType,
85 BDataType,
86 BScaleDataType,
87 ATileDesc,
88 BTileDesc,
89 AMmaTileDesc,
90 BMmaTileDesc,
91 ABlockTransferSrcScalarPerVector,
92 BBlockTransferSrcScalarPerVector,
93 MPerBlock,
94 NPerBlock,
95 KPerBlock,
96 MPerXDL,
97 NPerXDL,
98 MRepeat,
99 NRepeat,
100 KPack>{};
101 }
102 }
103 else
104 {
105 std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
106 }
107}
108
109} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
constexpr auto BlockGemmMXPipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
Definition blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp:38
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp:38