9#include <initializer_list>
33template <
typename DeviceOp,
39 typename AElementwiseOperation,
40 typename B0ElementwiseOperation,
41 typename AccElementwiseOperation,
42 typename B1ElementwiseOperation,
43 typename CElementwiseOperation,
44 bool HasMainKBlockLoop>
46#if CK_USE_LAUNCH_BOUNDS
50 const B0DataType* __restrict__ p_b0_grid,
51 const B1DataType* __restrict__ p_b1_grid,
52 CDataType* __restrict__ p_c_grid,
63#if(defined(__gfx11__) || defined(__gfx12__))
67 const auto q_head = G1;
68 const auto kv_head = 1;
70 constexpr index_t array_size = 4;
71 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, q_head, M, K};
72 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
74 ? std::array<ck::index_t, array_size>{M * q_head * K, K, q_head * K, 1}
75 : std::array<ck::index_t, array_size>{q_head * M * K, M * K, K, 1};
77 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, kv_head, N, K};
78 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
80 ? std::array<ck::index_t, array_size>{N * kv_head * K, K, kv_head * K, 1}
81 : std::array<ck::index_t, array_size>{kv_head * N * K, N * K, K, 1};
83 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, kv_head, O, N};
84 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
86 ? std::array<ck::index_t, array_size>{N * kv_head * O, O, 1, kv_head * O}
87 : std::array<ck::index_t, array_size>{kv_head * N * O, N * O, 1, O};
89 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, q_head, M, O};
90 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
92 ? std::array<ck::index_t, array_size>{M * q_head * O, O, q_head * O, 1}
93 : std::array<ck::index_t, array_size>{q_head * M * O, M * O, O, 1};
95 const auto a_element_op = AElementwiseOperation{};
96 const auto b0_element_op = B0ElementwiseOperation{};
97 const auto acc0_element_op = AccElementwiseOperation{alpha};
98 const auto b1_element_op = B1ElementwiseOperation{};
99 const auto c_element_op = CElementwiseOperation{};
102 const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
103 const auto b0_grid_desc =
104 DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
105 const auto b1_grid_desc =
106 DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
107 const auto c_grid_desc_m_n =
108 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
109 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
110 GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
111 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
113 const auto a_grid_desc_g_m_k =
114 DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
115 const auto b0_grid_desc_g_l_k =
116 DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
117 const auto b1_grid_desc_g_n_l =
118 DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
119 const auto c_grid_desc_g_m_n =
120 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
121 const auto compute_base_ptr_of_batch =
122 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
124 const auto c0_matrix_mask =
typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(
Number<1>{})};
127 __shared__
char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
128 const index_t num_blocks_per_batch =
129 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
132 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
133 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
134 const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
135 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx / G1)));
136 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
137 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx / G1)));
138 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
139 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
141 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
142 p_b0_grid + b0_batch_offset,
143 p_b1_grid + b1_batch_offset,
144 p_c_grid + c_batch_offset,
149 c_grid_desc_mblock_mperblock_nblock_nperblock,
187 typename Acc0BiasDataType,
188 typename Acc0DataType,
189 typename Acc1BiasDataType,
190 typename Acc1DataType,
191 typename CShuffleDataType,
192 typename AElementwiseOperation,
193 typename B0ElementwiseOperation,
194 typename AccElementwiseOperation,
195 typename B1ElementwiseOperation,
196 typename CElementwiseOperation,
218 typename ABlockTransferThreadClusterLengths_K0_M_K1,
219 typename ABlockTransferThreadClusterArrangeOrder,
220 typename ABlockTransferSrcAccessOrder,
224 bool ABlockLdsAddExtraM,
225 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
226 typename B0BlockTransferThreadClusterArrangeOrder,
227 typename B0BlockTransferSrcAccessOrder,
231 bool B0BlockLdsAddExtraL,
232 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
233 typename B1BlockTransferThreadClusterArrangeOrder,
234 typename B1BlockTransferSrcAccessOrder,
238 bool B1BlockLdsAddExtraN,
239 index_t CShuffleMRepeatPerShuffle,
240 index_t CShuffleNRepeatPerShuffle,
241 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
242 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
258 AElementwiseOperation,
259 B0ElementwiseOperation,
260 AccElementwiseOperation,
261 B1ElementwiseOperation,
262 CElementwiseOperation,
265 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
266 "Number of dimension must be greater than 0");
293 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
294 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
295 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
319 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
320 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
333 a_gs_ms_ks_strides_vec),
343 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths_vec,
344 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides_vec)
350 b0_gs_ls_ks_strides_vec),
358 b0_gs_ls_ks_strides_vec),
368 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths_vec,
369 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides_vec)
375 b1_gs_ns_ls_strides_vec),
383 b1_gs_ns_ls_strides_vec),
420 : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
421 b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
422 b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
423 c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
465 AElementwiseOperation,
466 B0ElementwiseOperation,
467 AccElementwiseOperation,
468 B1ElementwiseOperation,
469 CElementwiseOperation,
493 ABlockTransferThreadClusterLengths_K0_M_K1,
494 ABlockTransferThreadClusterArrangeOrder,
495 ABlockTransferSrcAccessOrder,
496 ABlockTransferSrcVectorDim,
497 ABlockTransferSrcScalarPerVector,
498 ABlockTransferDstScalarPerVector_K1,
502 B0BlockTransferThreadClusterLengths_K0_L_K1,
503 B0BlockTransferThreadClusterArrangeOrder,
504 B0BlockTransferSrcAccessOrder,
505 B0BlockTransferSrcVectorDim,
506 B0BlockTransferSrcScalarPerVector,
507 B0BlockTransferDstScalarPerVector_K1,
511 B1BlockTransferThreadClusterLengths_L0_N_L1,
512 B1BlockTransferThreadClusterArrangeOrder,
513 B1BlockTransferSrcAccessOrder,
514 B1BlockTransferSrcVectorDim,
515 B1BlockTransferSrcScalarPerVector,
516 B1BlockTransferDstScalarPerVector_L1,
520 CShuffleMRepeatPerShuffle,
521 CShuffleNRepeatPerShuffle,
522 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
523 CShuffleBlockTransferScalarPerVector_NPerBlock,
533 const B0DataType* p_b0_grid,
534 const B1DataType* p_b1_grid,
579 const B0DataType* p_b0,
580 const B1DataType* p_b1,
593 p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute};
602 printf(
"DeviceOp: Acc0 Type err");
608 printf(
"DeviceOp: Acc1 Type err");
614 printf(
"DeviceOp: Arch err");
618 constexpr index_t array_size = 4;
628 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
629 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
630 input_permute ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1}
632 : std::array<ck::index_t, array_size>{
633 G1 * M * K, M * K, K, 1};
635 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
636 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
637 input_permute ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1}
639 : std::array<ck::index_t, array_size>{
640 G1 * N * K, N * K, K, 1};
642 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
643 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
644 input_permute ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O}
646 : std::array<ck::index_t, array_size>{
647 G1 * N * O, N * O, 1, O};
649 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
650 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
651 output_permute ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1}
653 : std::array<ck::index_t, array_size>{
654 G1 * M * O, M * O, O, 1};
656 const auto a_grid_desc =
658 const auto b0_grid_desc =
660 const auto b1_grid_desc =
662 const auto c_grid_desc_m_n =
663 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
667 const auto c_grid_desc_g_m_n =
668 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
672 a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
678 const index_t c_g = c_grid_desc_g_m_n.GetLength(
I0);
680 if(!(c_g == batch_count))
682 printf(
"DeviceOp: BatchCount err");
689 const auto MzRaw = M;
690 const auto LzRaw = N;
691 const auto KzRaw = K;
692 const auto NzRaw = O;
695 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
696 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
697 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
698 const auto c_extent_lowest = NzRaw;
700 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
701 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
702 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
703 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
705 printf(
"DeviceOp: Data Transfer Vector scalar err");
709 std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
710 a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
711 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
712 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
713 b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
714 b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
715 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
716 b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
717 b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
718 std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
719 c_gs_ms_os_strides[NumDimG + NumDimM - 1],
720 c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};
723 const auto a_stride_lowest =
724 ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
725 const auto b0_stride_lowest =
726 B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
727 const auto b1_stride_lowest =
728 B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
729 const auto c_stride_lowest = c_mz_nz_strides_[1];
731 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
732 c_stride_lowest == 1))
734 printf(
"DeviceOp: Data Vectorize transfer err");
751 const ADataType* p_a_grid,
752 const B0DataType* p_b0_grid,
753 const B1DataType* p_b1_grid,
755 const std::array<void*, NumAcc0Bias> p_acc0_biases,
756 const std::array<void*, NumAcc1Bias> p_acc1_biases,
757 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
758 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
759 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
760 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
761 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
762 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
763 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
764 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
765 const std::array<std::vector<ck::index_t>,
NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
766 const std::array<std::vector<ck::index_t>,
NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
767 const std::array<std::vector<ck::index_t>,
NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
768 const std::array<std::vector<ck::index_t>,
NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
771 AElementwiseOperation a_element_op,
772 B0ElementwiseOperation b0_element_op,
773 AccElementwiseOperation acc_element_op,
774 B1ElementwiseOperation b1_element_op,
775 CElementwiseOperation c_element_op)
786 Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
788 Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
790 Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
792 Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
794 Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
804 b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
805 b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
806 b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
808 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
810 b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
812 b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
814 c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
822 ignore = acc0_biases_gs_ms_ls_lengths;
823 ignore = acc0_biases_gs_ms_ls_strides;
824 ignore = acc1_biases_gs_ms_ns_lengths;
825 ignore = acc1_biases_gs_ms_ns_strides;
892 const auto K = arg.
K_;
894 auto launch_kernel = [&](
auto has_main_k_block_loop) {
901 AElementwiseOperation,
902 B0ElementwiseOperation,
903 AccElementwiseOperation,
904 B1ElementwiseOperation,
905 CElementwiseOperation,
906 has_main_k_block_loop>;
942 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
958 printf(
"DeviceOp: Acc0 Type err");
964 printf(
"DeviceOp: Acc1 Type err");
970 printf(
"DeviceOp: Arch err");
977 arg.c_grid_desc_m_n_,
978 arg.block_2_ctile_map_))
984 const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(
I0);
986 if(!(c_g == arg.batch_count_))
988 printf(
"DeviceOp: BatchCount err");
995 const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
996 const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
997 const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
998 const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
1001 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
1002 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
1003 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
1004 const auto c_extent_lowest = NzRaw;
1006 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
1007 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
1008 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
1009 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
1011 printf(
"DeviceOp: Data Transfer Vector scalar err");
1016 const auto a_stride_lowest =
1017 ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
1018 const auto b0_stride_lowest =
1019 B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
1020 const auto b1_stride_lowest =
1021 B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
1022 const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
1024 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
1025 c_stride_lowest == 1))
1027 printf(
"DeviceOp: Data Vectorize transfer err");
1041 const ADataType* p_a,
1042 const B0DataType* p_b0,
1043 const B1DataType* p_b1,
1045 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1046 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1047 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
1048 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
1049 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
1050 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
1051 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
1052 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
1053 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
1054 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
1055 const std::array<std::vector<ck::index_t>,
NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1056 const std::array<std::vector<ck::index_t>,
NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1057 const std::array<std::vector<ck::index_t>,
NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1058 const std::array<std::vector<ck::index_t>,
NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1059 AElementwiseOperation a_element_op,
1060 B0ElementwiseOperation b0_element_op,
1061 AccElementwiseOperation acc_element_op,
1062 B1ElementwiseOperation b1_element_op,
1063 CElementwiseOperation c_element_op)
1073 b0_gs_ls_ks_lengths,
1074 b0_gs_ls_ks_strides,
1075 b1_gs_ns_ls_lengths,
1076 b1_gs_ns_ls_strides,
1079 acc0_biases_gs_ms_ls_lengths,
1080 acc0_biases_gs_ms_ls_strides,
1081 acc1_biases_gs_ms_ns_lengths,
1082 acc1_biases_gs_ms_ns_strides,
1099 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1100 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1101 const std::vector<index_t>& a_gs_ms_ks_lengths,
1102 const std::vector<index_t>& a_gs_ms_ks_strides,
1103 const std::vector<index_t>& b0_gs_ls_ks_lengths,
1104 const std::vector<index_t>& b0_gs_ls_ks_strides,
1105 const std::vector<index_t>& b1_gs_ns_ls_lengths,
1106 const std::vector<index_t>& b1_gs_ns_ls_strides,
1107 const std::vector<index_t>& c_gs_ms_ns_lengths,
1108 const std::vector<index_t>& c_gs_ms_ns_strides,
1109 const std::array<std::vector<ck::index_t>,
NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1110 const std::array<std::vector<ck::index_t>,
NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1111 const std::array<std::vector<ck::index_t>,
NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1112 const std::array<std::vector<ck::index_t>,
NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1113 AElementwiseOperation a_element_op,
1114 B0ElementwiseOperation b0_element_op,
1115 AccElementwiseOperation acc_element_op,
1116 B1ElementwiseOperation b1_element_op,
1117 CElementwiseOperation c_element_op)
override
1119 std::array<index_t, NumDimG + NumDimM + NumDimN> a_lengths;
1120 std::array<index_t, NumDimG + NumDimM + NumDimN> a_strides;
1121 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lengths;
1122 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_strides;
1123 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_lengths;
1124 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_strides;
1125 std::array<index_t, NumDimG + NumDimM + NumDimN> c_lengths;
1126 std::array<index_t, NumDimG + NumDimM + NumDimN> c_strides;
1127 std::transform(a_gs_ms_ks_lengths.begin(),
1128 a_gs_ms_ks_lengths.end(),
1131 std::transform(a_gs_ms_ks_strides.begin(),
1132 a_gs_ms_ks_strides.end(),
1135 std::transform(b0_gs_ls_ks_lengths.begin(),
1136 b0_gs_ls_ks_lengths.end(),
1139 std::transform(b0_gs_ls_ks_strides.begin(),
1140 b0_gs_ls_ks_strides.end(),
1143 std::transform(b1_gs_ns_ls_lengths.begin(),
1144 b1_gs_ns_ls_lengths.end(),
1147 std::transform(b1_gs_ns_ls_strides.begin(),
1148 b1_gs_ns_ls_strides.end(),
1151 std::transform(c_gs_ms_ns_lengths.begin(),
1152 c_gs_ms_ns_lengths.end(),
1155 std::transform(c_gs_ms_ns_strides.begin(),
1156 c_gs_ms_ns_strides.end(),
1159 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
1160 static_cast<const B0DataType*
>(p_b0),
1161 static_cast<const B1DataType*
>(p_b1),
1162 static_cast<CDataType*
>(p_c),
1173 acc0_biases_gs_ms_ls_lengths,
1174 acc0_biases_gs_ms_ls_strides,
1175 acc1_biases_gs_ms_ns_lengths,
1176 acc1_biases_gs_ms_ns_strides,
1191 return std::make_unique<Invoker>(
Invoker{});
1197 auto str = std::stringstream();
1199 std::map<LoopScheduler, std::string> LoopSchedToString{
1202 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
1206 str <<
"DeviceMultiQueryAttentionForward_Wmma"
1208 << BlockSize <<
", "
1209 << MPerBlock <<
", "
1210 << LPerBlock <<
", "
1211 << KPerBlock <<
", "
1214 << MPerBlock <<
", "
1215 << NPerBlock <<
", "
1216 << LTilePerBlock <<
", "
1232 << NumPrefetch <<
", "
1233 <<
"LoopScheduler: "
1234 << LoopSchedToString[LoopSched] <<
", "
1235 <<
"PipelineVersion: "
1236 << PipelineVersionToString[PipelineVer];
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_multi_query_attention_wmma(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_multi_query_attention_forward_wmma.hpp:49
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
int64_t long_index_t
Definition ck.hpp:300
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:93
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::DefaultBlock2CTileMap remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:682
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:672
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:679
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:653
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:645
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:511
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
Definition device_multi_query_attention_forward_wmma.hpp:749
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_
Definition device_multi_query_attention_forward_wmma.hpp:879
B0GridDesc_G_L_K b0_grid_desc_g_l_k_
Definition device_multi_query_attention_forward_wmma.hpp:849
B1ElementwiseOperation b1_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:863
const B0DataType * p_b0_grid_
Definition device_multi_query_attention_forward_wmma.hpp:838
AccElementwiseOperation acc_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:862
Argument(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, const index_t M01, const index_t N01, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_multi_query_attention_forward_wmma.hpp:750
const B1DataType * p_b1_grid_
Definition device_multi_query_attention_forward_wmma.hpp:839
B0GridDesc b0_grid_desc
Definition device_multi_query_attention_forward_wmma.hpp:844
B1GridDesc b1_grid_desc
Definition device_multi_query_attention_forward_wmma.hpp:845
CGridDesc_M_N c_grid_desc_m_n_
Definition device_multi_query_attention_forward_wmma.hpp:846
std::array< index_t, NumDimG+NumDimM+NumDimN > b0_lz_kz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:873
const ADataType * p_a_grid_
Definition device_multi_query_attention_forward_wmma.hpp:837
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_multi_query_attention_forward_wmma.hpp:854
CGridDesc_G_M_N c_grid_desc_g_m_n_
Definition device_multi_query_attention_forward_wmma.hpp:851
index_t batch_count_
Definition device_multi_query_attention_forward_wmma.hpp:877
AGridDesc a_grid_desc
Definition device_multi_query_attention_forward_wmma.hpp:843
C0MatrixMask c0_matrix_mask_
Definition device_multi_query_attention_forward_wmma.hpp:867
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_multi_query_attention_forward_wmma.hpp:857
std::array< index_t, NumDimG+NumDimM+NumDimN > a_mz_kz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:872
std::array< index_t, NumDimG+NumDimM+NumDimN > c_mz_nz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:875
B1GridDesc_G_N_L b1_grid_desc_g_n_l_
Definition device_multi_query_attention_forward_wmma.hpp:850
std::array< index_t, NumDimG+NumDimM+NumDimN > raw_lengths_mz_lz_kz_nz_
Definition device_multi_query_attention_forward_wmma.hpp:871
CDataType * p_c_grid_
Definition device_multi_query_attention_forward_wmma.hpp:840
B0ElementwiseOperation b0_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:861
CElementwiseOperation c_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:864
AElementwiseOperation a_element_op_
Definition device_multi_query_attention_forward_wmma.hpp:860
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition device_multi_query_attention_forward_wmma.hpp:848
std::array< index_t, NumDimG+NumDimM+NumDimN > b1_nz_lz_strides_
Definition device_multi_query_attention_forward_wmma.hpp:874
Definition device_multi_query_attention_forward_wmma.hpp:415
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:437
__host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const B0GridDesc_G_L_K &b0_grid_desc_g_l_k, const B1GridDesc_G_N_L &b1_grid_desc_g_n_l, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_multi_query_attention_forward_wmma.hpp:416
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:432
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:442
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_multi_query_attention_forward_wmma.hpp:427
Definition device_multi_query_attention_forward_wmma.hpp:883
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_multi_query_attention_forward_wmma.hpp:886
DeviceOp::RawArg Argument
Definition device_multi_query_attention_forward_wmma.hpp:884
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_multi_query_attention_forward_wmma.hpp:939
Definition device_multi_query_attention_forward_wmma.hpp:531
index_t G1_
Definition device_multi_query_attention_forward_wmma.hpp:572
const B0DataType * p_b0_grid_
Definition device_multi_query_attention_forward_wmma.hpp:562
float alpha_
Definition device_multi_query_attention_forward_wmma.hpp:573
const ADataType * p_a_grid_
Definition device_multi_query_attention_forward_wmma.hpp:561
bool input_permute_
Definition device_multi_query_attention_forward_wmma.hpp:574
bool output_permute_
Definition device_multi_query_attention_forward_wmma.hpp:575
const B1DataType * p_b1_grid_
Definition device_multi_query_attention_forward_wmma.hpp:563
index_t K_
Definition device_multi_query_attention_forward_wmma.hpp:569
index_t M_
Definition device_multi_query_attention_forward_wmma.hpp:567
CDataType * p_c_grid_
Definition device_multi_query_attention_forward_wmma.hpp:564
index_t N_
Definition device_multi_query_attention_forward_wmma.hpp:568
index_t O_
Definition device_multi_query_attention_forward_wmma.hpp:570
RawArg(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_multi_query_attention_forward_wmma.hpp:532
index_t G0_
Definition device_multi_query_attention_forward_wmma.hpp:571
Definition device_multi_query_attention_forward_wmma.hpp:264
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_multi_query_attention_forward_wmma.hpp:396
static constexpr auto I5
Definition device_multi_query_attention_forward_wmma.hpp:288
static constexpr auto WmmaK
Definition device_multi_query_attention_forward_wmma.hpp:291
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b0_gs_ls_ks_lengths, const std::vector< index_t > &b0_gs_ls_ks_strides, const std::vector< index_t > &b1_gs_ns_ls_lengths, const std::vector< index_t > &b1_gs_ns_ls_strides, const std::vector< index_t > &c_gs_ms_ns_lengths, const std::vector< index_t > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_multi_query_attention_forward_wmma.hpp:1094
static constexpr index_t NumAcc0Bias
Definition device_multi_query_attention_forward_wmma.hpp:268
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides_vec)
Definition device_multi_query_attention_forward_wmma.hpp:367
static constexpr auto NWaves
Definition device_multi_query_attention_forward_wmma.hpp:295
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition device_multi_query_attention_forward_wmma.hpp:318
static constexpr auto B0EnableLds_auto
Definition device_multi_query_attention_forward_wmma.hpp:298
static constexpr index_t NumDimGemm1K
Definition device_multi_query_attention_forward_wmma.hpp:279
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_L
Definition device_multi_query_attention_forward_wmma.hpp:398
static constexpr bool IsValidCompilationParameter()
Definition device_multi_query_attention_forward_wmma.hpp:946
GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_multi_query_attention_forward_wmma.hpp:455
DeviceMultiQueryAttentionForward_Wmma DeviceOp
Definition device_multi_query_attention_forward_wmma.hpp:281
static constexpr index_t NumDimGemm0N
Definition device_multi_query_attention_forward_wmma.hpp:275
static constexpr index_t NumDimGemm1N
Definition device_multi_query_attention_forward_wmma.hpp:278
static constexpr auto AEnableLds
Definition device_multi_query_attention_forward_wmma.hpp:305
static constexpr auto B0EnableLds
Definition device_multi_query_attention_forward_wmma.hpp:306
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_multi_query_attention_forward_wmma.hpp:394
static bool IsSupportedArgument(const RawArg &arg)
Definition device_multi_query_attention_forward_wmma.hpp:596
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_multi_query_attention_forward_wmma.hpp:393
static constexpr index_t NumDimGemm1M
Definition device_multi_query_attention_forward_wmma.hpp:277
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_multi_query_attention_forward_wmma.hpp:742
static constexpr auto I2
Definition device_multi_query_attention_forward_wmma.hpp:285
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_multi_query_attention_forward_wmma.hpp:399
static constexpr auto LWaves
Definition device_multi_query_attention_forward_wmma.hpp:294
static constexpr auto B1EnableLds
Definition device_multi_query_attention_forward_wmma.hpp:307
static constexpr auto I1
Definition device_multi_query_attention_forward_wmma.hpp:284
static constexpr auto I3
Definition device_multi_query_attention_forward_wmma.hpp:286
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_multi_query_attention_forward_wmma.hpp:392
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_multi_query_attention_forward_wmma.hpp:395
static constexpr auto I6
Definition device_multi_query_attention_forward_wmma.hpp:289
static constexpr auto MWaves
Definition device_multi_query_attention_forward_wmma.hpp:293
static constexpr auto AEnableLds_manu
Definition device_multi_query_attention_forward_wmma.hpp:301
static constexpr auto AEnableLds_auto
Definition device_multi_query_attention_forward_wmma.hpp:297
static constexpr auto B0EnableLds_manu
Definition device_multi_query_attention_forward_wmma.hpp:302
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides_vec)
Definition device_multi_query_attention_forward_wmma.hpp:342
static constexpr auto I0
Definition device_multi_query_attention_forward_wmma.hpp:283
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< NumDimG, NumDimM, NumDimL, NumDimK, NumDimN >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, ASpec, B0Spec, B1Spec, CSpec > Transform
Definition device_multi_query_attention_forward_wmma.hpp:309
static constexpr index_t NumDimGemm0M
Definition device_multi_query_attention_forward_wmma.hpp:274
__host__ __device__ static constexpr auto make_MaskOutPredicate()
Definition device_multi_query_attention_forward_wmma.hpp:401
static constexpr index_t NumDimGemm0K
Definition device_multi_query_attention_forward_wmma.hpp:276
std::string GetTypeString() const override
Definition device_multi_query_attention_forward_wmma.hpp:1195
static auto MakeInvoker()
Definition device_multi_query_attention_forward_wmma.hpp:1186
static auto MakeArgument(const ADataType *p_a, const B0DataType *p_b0, const B1DataType *p_b1, CDataType *p_c, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_multi_query_attention_forward_wmma.hpp:578
static constexpr auto B1EnableLds_auto
Definition device_multi_query_attention_forward_wmma.hpp:299
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_multi_query_attention_forward_wmma.hpp:1189
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_multi_query_attention_forward_wmma.hpp:412
static constexpr index_t NumAcc1Bias
Definition device_multi_query_attention_forward_wmma.hpp:269
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) B0GridDesc_G_L_K
Definition device_multi_query_attention_forward_wmma.hpp:397
static constexpr auto B1EnableLds_manu
Definition device_multi_query_attention_forward_wmma.hpp:303
static constexpr auto I4
Definition device_multi_query_attention_forward_wmma.hpp:287
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43