31template <
typename SrcDatas,
35 typename ElementwiseOperation,
37 typename SliceLengths,
38 typename SrcDimAccessOrder,
39 typename DstDimAccessOrder,
42 typename SrcScalarPerVectors,
44 typename SrcResetCoordinateAfterRunFlags,
45 typename DstResetCoordinateAfterRunFlags,
48 bool OutputScatter =
true,
69 template <
typename Descs,
71 enable_if_t<Descs::Size() == Indices::Size(),
bool> =
false>
100 const SrcDescs& src_descs,
102 const DstDescs& dst_descs,
104 const ElementwiseOperation& element_op)
107 element_op_(element_op)
110 "wrong! cannot evenly divide");
113 "wrong! cannot evenly divide");
116 template <
typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(),
bool> = false>
118 const Indices& src_slice_origin_idxs)
125 template <
typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(),
bool> = false>
127 const Indices& dst_slice_origin_idxs)
134 template <
typename DataTypes, index_t ScalarPerVector>
137 auto data_types = DataTypes{};
139 constexpr index_t num = data_types.Size();
152 template <
typename SrcBuffers,
154 enable_if_t<SrcDescs::Size() == SrcBuffers::Size(),
bool> =
false>
155 __device__
void RunRead(
const SrcDescs& src_descs,
156 const SrcBuffers& src_bufs,
167 static_for<0, nSrc, 1>{}([&](
auto i) {
168 using src_vector_t =
typename remove_cvref_t<
decltype(src_vectors[i])>::type;
170 const bool is_src_valid =
174 oob_val = oob_val & is_src_valid;
175 src_vectors(i).template AsType<src_vector_t>()(
I0) =
176 src_bufs[i].
template Get<src_vector_t>(src_coords_[i].GetOffset(),
true);
179 constexpr auto get_elem_op_vec_len = []() {
182 if constexpr(
decltype(element_op_)::is_pack8_invocable)
187 if constexpr(
decltype(element_op_)::is_pack4_invocable)
192 if constexpr(
decltype(element_op_)::is_pack2_invocable)
198 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
205 [&](
auto iSrc) ->
const auto& {
208 using elem_op_vec_t =
typename vector_type<SrcData, elem_op_vec_len>::type;
210 return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
217 [&](
auto iDst) ->
auto& {
220 using elem_op_vec_t =
typename vector_type<DstData, elem_op_vec_len>::type;
222 return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
234 unpack2(element_op_, dst_data_refs, src_data_refs);
237 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
238 oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
241 if constexpr(iAccess.value != src_num_access - 1)
245 static_for<0, nSrc, 1>{}([&](
auto i) {
254 static_for<0, nSrc, 1>{}([&](
auto i) {
255 if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
257 const auto src_reset_step =
266 template <index_t ThreadScratchId = 0>
271 auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
272 auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
274 static_for<0, nDst, 1>{}([&](
auto i) {
275 using elm_vector_t =
typename remove_cvref_t<
decltype(elm_vectors[i])>::type;
276 elm_vectors(i).template AsType<elm_vector_t>()(
I0) =
277 oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[
I0] : elm_vector_t{0};
280 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
285 template <index_t ThreadScratchId = 0>
291 using ElmThreadScratch =
297 using DstThreadScratch =
304 ElmThreadScratch elm_thread_scratch_;
305 DstThreadScratch dst_thread_scratch_;
307 elm_thread_scratch_.data_ =
308 bit_cast<
decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
310 if constexpr(SrcVectorDim != DstVectorDim &&
311 ((is_same<half_t, remove_cvref_t<DstData>>
::value &&
313 (is_same<f8_t, remove_cvref_t<DstData>>
::value &&
315 (is_same<int8_t, remove_cvref_t<DstData>>
::value &&
328 detail::lambda_scalar_step_in_vector<SrcVectorDim>{},
Number<nDim>{});
331 detail::lambda_scalar_step_in_vector<DstVectorDim>{},
Number<nDim>{});
334 detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
337 DstScalarPerVector>{},
340 constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
342 static_ford<
decltype(access_lengths)>{}([&](
auto access_idx) {
343 constexpr auto data_idx = access_idx * scalar_per_access;
354 [&](
auto i) ->
const src_vector_t& {
356 return elm_thread_scratch_.GetVectorTypeReference(
357 data_idx_seq + i * dst_scalar_step_in_vector);
364 [&](
auto i) -> dst_vector_t& {
366 return dst_thread_scratch_.GetVectorTypeReference(
367 data_idx_seq + i * src_scalar_step_in_vector);
372 transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
373 src_vector_refs, dst_vector_refs);
378 static_ford<SliceLengths>{}(
379 [&](
auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
387 template <
typename DstBuffers,
389 enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1,
bool> =
false>
390 __device__
void RunWrite(
const DstDescs& dst_descs,
400 auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
401 IndexType scatter_offset = 0;
402 if constexpr(OutputScatter)
404 constexpr auto iScatter =
409 static_for<0, nDst, 1>{}([&](
auto i) {
410 using dst_vector_t =
typename remove_cvref_t<
decltype(dst_vectors[i])>::type;
411 IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
412 const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
415 dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
416 dst_offset, is_dst_valid, dst_vectors[i].
template AsType<dst_vector_t>()[
I0]);
420 if constexpr(iAccess.value != dst_num_access - 1)
424 auto forward_step_scatter = [&]()
constexpr {
427 static_for<0, nDim, 1>{}([&](
auto i) {
428 step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
433 static_for<0, nDst, 1>{}([&](
auto i) {
442 static_for<0, nDst, 1>{}([&](
auto i) {
443 if constexpr(DstResetCoordinateAfterRunFlags::At(i))
445 const auto dst_reset_step =
457 template <
typename SrcBuffers,
459 enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
460 DstDescs::Size() == DstBuffers::Size(),
462 __device__
void Run(
const SrcDescs& src_descs,
463 const SrcBuffers& src_bufs,
464 const DstDescs& dst_descs,
469 RunWrite(dst_descs, dst_bufs, scatter_offsets);
474 if constexpr(src_num_access == 0)
486 if constexpr(dst_num_access == 0)
492 constexpr auto reset_step =
494 auto reset_step_scatter = [&]()
constexpr {
498 (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[
Number<i>{}];
503 return reset_step_scatter;
519 constexpr auto desc0 =
525 if constexpr(i == SrcVectorDim)
528 make_tuple(src_access_lengths_and_vector_length[i],
540 if constexpr(i == SrcVectorDim)
551 constexpr auto up_dim_idss =
569 constexpr auto desc0 =
575 if constexpr(i == DstVectorDim)
578 make_tuple(dst_access_lengths_and_vector_length[i],
590 if constexpr(i == DstVectorDim)
601 constexpr auto up_dim_idss =
608 template <index_t ISrc>
611 const Index& src_slice_origin_step_idx)
614 const auto adjusted_step_idx =
615 SrcResetCoordinateAfterRunFlags::At(iSrc)
616 ? src_slice_origin_step_idx
626 template <index_t IDst>
629 const Index& dst_slice_origin_step_idx)
632 const auto adjusted_step_idx =
633 DstResetCoordinateAfterRunFlags::At(iDst)
634 ? dst_slice_origin_step_idx
637 auto adjusted_step_idx_scatter = [&]() {
641 (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[
Number<i>{}];
647 const auto adjusted_step =
672 const ElementwiseOperation element_op_;
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
decltype(ck::declval< T & >().is_pack8_invocable) is_pack8_invocable_t
Definition is_detected.hpp:43
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
decltype(ck::declval< T & >().is_pack4_invocable) is_pack4_invocable_t
Definition is_detected.hpp:40
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
decltype(ck::declval< T & >().is_pack2_invocable) is_pack2_invocable_t
Definition is_detected.hpp:37
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr Index GetIndex(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:81
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::scatter_num static constexpr index_t scatter_num
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:66
__device__ void SetDstSliceOrigins(const DstDescs &dst_descs, const Indices &dst_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:126
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::SrcCoords decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray< Index, nSrc >{})) SrcCoords
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:78
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::I0 static constexpr auto I0
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:53
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::nSrc static constexpr index_t nSrc
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:62
__device__ void SetSrcSliceOrigins(const SrcDescs &src_descs, const Indices &src_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:117
__device__ void TransposeFromElmToDst(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:287
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r3_scatter(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_slice_origins, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:99
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::I2 static constexpr auto I2
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:55
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::dst_scalar_per_access static constexpr auto dst_scalar_per_access
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:86
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:609
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::nDst static constexpr index_t nDst
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:63
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::Index MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:65
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, StaticallyIndexedArray< IndexType, scatter_num > &scatter_offsets, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:390
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::nDim static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:60
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:507
__device__ void OOBCheck(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:267
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::SrcScalarPerVector static constexpr auto SrcScalarPerVector
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:58
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::DstSpaceFillingCurve SpaceFillingCurve< decltype(thread_slice_lengths), DstDimAccessOrder, remove_cv_t< decltype(dst_scalar_per_access)>, false > DstSpaceFillingCurve
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:94
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::I1 static constexpr auto I1
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:54
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::I3 static constexpr auto I3
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:56
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:472
static __device__ auto generate_vectors()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:135
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:484
static constexpr auto MakeCoordinates(const Descs &descs, const Indices &indices)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:72
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::src_scalar_per_access static constexpr auto src_scalar_per_access
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:83
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:155
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::SrcSpaceFillingCurve SpaceFillingCurve< decltype(thread_slice_lengths), SrcDimAccessOrder, remove_cv_t< decltype(src_scalar_per_access)>, false > SrcSpaceFillingCurve
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:89
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:557
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:627
ck::ThreadwiseTensorSliceTransfer_v7r3_scatter< SrcDatas, DstDatas, SrcDescs, DstDescs, ElementwiseOperation, DstInMemOps, decltype(thread_slice_lengths), SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVectors, DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, NumThreadScratch >::DstCoords decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray< Index, nDst >{})) DstCoords
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:79
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs, StaticallyIndexedArray< IndexType, scatter_num > &scatter_offsets)
Definition threadwise_tensor_slice_transfer_v7r3_scatter.hpp:462
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33