93 const BTensorType& b_local_tile_tensor,
94 CTensorType& c_reg_tensor)
96 constexpr auto I3 = Number<3>{};
98 static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
99 static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
100 static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
101 static_assert(is_same_v<DataType, typename ATensorType::TensorElementType>);
102 static_assert(is_same_v<DataType, typename BTensorType::TensorElementType>);
104 constexpr bool is_integer =
105 is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
106 using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
108 using ATileLayout = remove_cvref_t<
decltype(
layout(a_local_tile_tensor))>;
109 using BTileLayout = remove_cvref_t<
decltype(
layout(b_local_tile_tensor))>;
111 static_assert(
typename ATileLayout::LayoutShape{}.Size() ==
112 typename BTileLayout::LayoutShape{}.Size());
113 constexpr bool is_3d_desc =
typename ATileLayout::LayoutShape{}.Size() == I3;
115 using ABlockDesc_K0_M_K1_Type =
116 conditional_t<is_3d_desc,
117 typename ATileLayout::LayoutUnrolledDescriptorType,
118 decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
119 using BBlockDesc_K0_N_K1_Type =
120 conditional_t<is_3d_desc,
121 typename BTileLayout::LayoutUnrolledDescriptorType,
122 decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
124 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
128 ABlockDesc_K0_M_K1_Type,
129 BBlockDesc_K0_N_K1_Type,
132 GemmTraits::MXdlPerWave,
133 GemmTraits::NXdlPerWave,
135 blockwise_gemm_xdl_op{};
137 blockwise_gemm_xdl_op.Run(
138 a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer());
179 constexpr auto I0 = Number<0>{};
180 constexpr auto I1 = Number<1>{};
181 constexpr auto I2 = Number<2>{};
182 constexpr auto I3 = Number<3>{};
183 constexpr auto I4 = Number<4>{};
184 constexpr auto I5 = Number<5>{};
185 constexpr auto I6 = Number<6>{};
186 constexpr auto I7 = Number<7>{};
188 static_assert(
typename ATileLayout::LayoutShape{}.Size() ==
189 typename BTileLayout::LayoutShape{}.Size());
191 constexpr bool is_integer =
192 is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
193 using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
195 constexpr bool is_3d_desc =
typename ATileLayout::LayoutShape{}.Size() == I3;
196 using ABlockDesc_K0_M_K1_Type =
197 conditional_t<is_3d_desc,
198 typename ATileLayout::LayoutUnrolledDescriptorType,
199 decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
200 using BBlockDesc_K0_N_K1_Type =
201 conditional_t<is_3d_desc,
202 typename BTileLayout::LayoutUnrolledDescriptorType,
203 decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
205 using BlockwiseGemmXdlops =
206 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
210 ABlockDesc_K0_M_K1_Type,
211 BBlockDesc_K0_N_K1_Type,
214 GemmTraits::MXdlPerWave,
215 GemmTraits::NXdlPerWave,
218 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
219 BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
220 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
221 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
222 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
223 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
224 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
225 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
226 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
227 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
230 const auto c_thread_mtx_on_block =
231 BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
233 const index_t m_thread_data_on_grid =
234 c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0];
236 const index_t n_thread_data_on_grid =
237 c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1];
239 const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor(
240 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
241 make_tuple(Sequence<0, 1, 2, 3, 4>{}),
242 make_tuple(Sequence<0>{}));
244 const auto m_thread_data_on_grid_idx =
245 m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
246 make_multi_index(m_thread_data_on_grid));
248 const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
249 make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
250 make_tuple(Sequence<0, 1, 2>{}),
251 make_tuple(Sequence<0>{}));
253 const auto n_thread_data_on_grid_idx =
254 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
255 make_multi_index(n_thread_data_on_grid));
257 const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1);
259 const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
260 layout(c_local_tile_tensor).GetUnrolledDescriptor());
262 const auto lower_upper_dims =
263 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; }, Number<8>{});
265 auto sliced_desc = transform_tensor_descriptor(
268 make_slice_transform(partition_shape.At(Number<0>{}),
269 m_thread_data_on_grid_idx[I0],
270 partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
271 make_slice_transform(partition_shape.At(Number<1>{}),
272 n_thread_data_on_grid_idx[I0],
273 partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
274 make_slice_transform(partition_shape.At(Number<2>{}),
275 m_thread_data_on_grid_idx[I1],
276 partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
277 make_slice_transform(partition_shape.At(Number<3>{}),
278 n_thread_data_on_grid_idx[I1],
279 partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
280 make_slice_transform(partition_shape.At(Number<4>{}),
281 m_thread_data_on_grid_idx[I2],
282 partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
283 make_slice_transform(partition_shape.At(Number<5>{}),
284 m_thread_data_on_grid_idx[I3],
285 partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
286 make_slice_transform(partition_shape.At(Number<6>{}),
287 m_thread_data_on_grid_idx[I4],
288 partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
289 make_slice_transform(partition_shape.At(Number<7>{}),
290 n_thread_data_on_grid_idx[I2],
291 partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
295 const auto partition_layout =
296 Layout<remove_reference_t<
decltype(partition_shape)>,
decltype(sliced_desc)>(
297 partition_shape, sliced_desc);
299 c_local_tile_tensor.GetPointer(), partition_layout);
300 return partition_tensor;
337 constexpr auto I0 = Number<0>{};
338 constexpr auto I1 = Number<1>{};
339 constexpr auto I2 = Number<2>{};
340 constexpr auto I3 = Number<3>{};
341 constexpr auto I4 = Number<4>{};
342 constexpr auto I5 = Number<5>{};
343 constexpr auto I6 = Number<6>{};
344 constexpr auto I7 = Number<7>{};
346 static_assert(
typename ATileLayout::LayoutShape{}.Size() ==
347 typename BTileLayout::LayoutShape{}.Size());
349 constexpr bool is_integer =
350 is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
351 using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
353 constexpr bool is_3d_desc =
typename ATileLayout::LayoutShape{}.Size() == I3;
354 using ABlockDesc_K0_M_K1_Type =
355 conditional_t<is_3d_desc,
356 typename ATileLayout::LayoutUnrolledDescriptorType,
357 decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
358 using BBlockDesc_K0_N_K1_Type =
359 conditional_t<is_3d_desc,
360 typename BTileLayout::LayoutUnrolledDescriptorType,
361 decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
363 using BlockwiseGemmXdlops =
364 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
368 ABlockDesc_K0_M_K1_Type,
369 BBlockDesc_K0_N_K1_Type,
372 GemmTraits::MXdlPerWave,
373 GemmTraits::NXdlPerWave,
376 constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
377 const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0],
378 vgpr_desc.GetLengths()[I1],
379 vgpr_desc.GetLengths()[I2],
380 vgpr_desc.GetLengths()[I3],
381 vgpr_desc.GetLengths()[I4],
382 vgpr_desc.GetLengths()[I5],
383 vgpr_desc.GetLengths()[I6],
384 vgpr_desc.GetLengths()[I7]);
385 const auto vgpr_layout =
Layout<remove_reference_t<
decltype(vgpr_shape)>,
decltype(vgpr_desc)>(
386 vgpr_shape, vgpr_desc);
388 constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
389 using VgprVectorType =
typename vector_type<GemmAccDataType, ScalarPerVector>::type;
390 return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(