thread_group_tensor_slice_transfer_v7r3.hpp Source File

thread_group_tensor_slice_transfer_v7r3.hpp Source File#

Composable Kernel: thread_group_tensor_slice_transfer_v7r3.hpp Source File
thread_group_tensor_slice_transfer_v7r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15// Thread-group level multi-source, multi-destination tensor slice data movement
16// Assume:
17// 1. All sources and destinations are DynamicBuffer
18// 2. Same VectorDim and ScalarPerVector for all sources and destinations
19// 3. DstInMemOps are per destination tensor
20// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
21// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
22//
23// Does the following to avoid scratch memory issues:
24// 1. Passes tensor descriptors by reference (or as a tuple of references)
25// 2. Does not keep a reference to any tensor descriptor
26// 3. Does not construct a new tensor coordinate when Run() is called
// Template header and leading constants of the thread-group transfer class.
//
// NOTE(review): this listing is a documentation extract. Original line 47
// (the struct declaration; the constructor in this file is named
// ThreadGroupTensorSliceTransfer_v7r3) and original lines 50, 52, 53 and 55
// are absent. The cross-reference at the end of the listing places nSrc on
// line 52, nDst on line 53 and Index (MultiIndex<nDim>) on line 55; the
// initializer of nDim is not visible here.
//
// Constraints enforced by static_asserts in the constructor:
//  - SrcDatas, SrcDescs and ThreadTransferSrcResetCoordinateAfterRunFlags each
//    have nSrc entries; DstDatas, DstDescs and
//    ThreadTransferDstResetCoordinateAfterRunFlags each have nDst entries.
//  - Every source and destination descriptor has nDim dimensions.
//  - ThreadClusterLengths, ThreadClusterArrangeOrder, SrcDimAccessOrder and
//    DstDimAccessOrder all have nDim entries.
//  - ThreadGroup provides at least as many threads as the thread cluster has
//    elements.
27template <typename ThreadGroup,
28 typename SrcDatas,
29 typename DstDatas,
30 typename SrcDescs,
31 typename DstDescs,
32 typename ElementwiseOperation,
33 typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
34 typename SliceLengths,
// ThreadClusterLengths and ThreadClusterArrangeOrder are the two arguments of
// make_cluster_descriptor() that build thread_cluster_desc_.
35 typename ThreadClusterLengths,
36 typename ThreadClusterArrangeOrder,
37 typename SrcDimAccessOrder,
38 typename DstDimAccessOrder,
39 index_t SrcVectorDim,
40 index_t DstVectorDim,
41 typename SrcScalarPerVectors,
42 index_t DstScalarPerVector,
43 typename ThreadTransferSrcResetCoordinateAfterRunFlags,
44 typename ThreadTransferDstResetCoordinateAfterRunFlags,
// NumThreadScratch and InterDatas are not used directly by this class; they
// are forwarded unchanged to ThreadwiseTensorSliceTransfer_v7r3.
45 index_t NumThreadScratch = 1,
46 typename InterDatas = DstDatas>
48{
49 static constexpr index_t nDim =
51
54
56
// Slice lengths handled by a single thread: the group-level slice lengths
// divided by the thread cluster lengths. The constructor statically asserts
// that multiplying this back by ThreadClusterLengths reproduces SliceLengths,
// i.e. the division leaves no remainder.
57 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
58
// Constructor: builds the per-thread transfer object, validates the template
// arguments at compile time, and positions each participating thread on its
// own sub-slice of the source and destination windows.
//
// NOTE(review): original lines 59, 66 and 68 are absent from this listing.
// Per the cross-reference at the end of the listing, line 59 declares
//   __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3(
// Lines 66 and 68 are the second and fourth arguments passed to
// threadwise_transfer_; their contents are not visible here.
//
// src_descs / dst_descs          tuples of source / destination descriptors
// src_block_slice_origins        one Index per source: origin of the group's slice
// dst_block_slice_origins        one Index per destination: origin of the group's slice
// element_op                     elementwise operation forwarded to the threadwise transfer
60 const SrcDescs& src_descs,
61 const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
62 const DstDescs& dst_descs,
63 const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
64 const ElementwiseOperation& element_op)
65 : threadwise_transfer_(src_descs,
67 dst_descs,
69 element_op)
70 {
// The per-source and per-destination type lists must all have matching sizes.
71 static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
72 nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
73 nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
74 nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
75 "wrong!");
76
// Every source descriptor must have nDim dimensions.
77 static_for<0, nSrc, 1>{}([&](auto i) {
78 static_assert(
79 nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
80 "wrong!");
81 });
82
// Every destination descriptor must have nDim dimensions.
83 static_for<0, nDst, 1>{}([&](auto i) {
84 static_assert(
85 nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
86 "wrong!");
87 });
88
89 static_assert(nDim == ThreadClusterLengths::Size() &&
90 nDim == ThreadClusterArrangeOrder::Size() &&
91 nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
92 "wrong! nDim not consistent");
93
// thread_slice_lengths * ThreadClusterLengths must give back SliceLengths
// exactly, so the per-thread slices tile the whole window.
94 static_assert(
95 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
96 "wrong! threads should be mapped to cover entire slicing window");
97
98 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
99 "wrong! ThreadGroup::GetNumOfThread() too small");
100
// Only threads whose id falls inside the cluster take part. The first
// comparison uses the same two quantities as the static_assert above, so it is
// presumably resolved at compile time when the sizes match — TODO confirm.
101 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
102 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
103 {
// Convert the flat thread id into a multi-dimensional cluster index.
104 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
105 make_multi_index(ThreadGroup::GetThreadId()));
106
// Offset of this thread's sub-slice relative to the group's slice origin.
107 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
108
// Per-thread origin for each source = group origin + thread offset.
109 const auto src_thread_slice_origins = generate_tuple(
110 [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
111 Number<nSrc>{});
112
// Per-thread origin for each destination = group origin + thread offset.
113 const auto dst_thread_slice_origins = generate_tuple(
114 [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
115 Number<nDst>{});
116
117 threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
118 threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
119 }
120 }
121
122 template <typename SrcBuffers, index_t ThreadScratchId = 0>
123 __device__ void RunRead(const SrcDescs& src_descs,
124 const SrcBuffers& src_bufs,
126 {
127 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
128 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
129 {
130 threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
131 }
132 }
133
// Detection helper: names the return type of T::IsTuple() and is ill-formed
// when T has no such member. The Run* methods use it with is_detected<> to
// decide whether a buffer argument must be wrapped with tie().
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
136
137 template <typename DstBuffers, index_t ThreadScratchId = 0>
138 __device__ void RunWrite(const DstDescs& dst_descs,
139 DstBuffers dst_bufs,
141 {
142 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
143 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
144 {
145 if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
146 threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
147 else
148 threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
149 }
150 }
151
152 template <typename DstBuffers,
153 typename DstVgprDescs,
154 typename DstVgprBuffers,
155 index_t ThreadScratchId = 0>
156 __device__ void
157 RunWriteAndStoreVgpr(const DstDescs& dst_descs,
158 DstBuffers dst_bufs,
159 const DstVgprDescs& dst_vgpr_desc,
160 DstVgprBuffers dst_vgpr_buf,
162 {
163 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
164 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
165 {
166 if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value &&
167 is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
168 threadwise_transfer_.RunWriteAndStoreVgpr(
169 dst_descs, dst_bufs, dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
170 else if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
171 threadwise_transfer_.RunWriteAndStoreVgpr(
172 dst_descs, dst_bufs, dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
173 else if constexpr(is_detected<is_tuple, decltype(dst_vgpr_buf)>::value)
174 threadwise_transfer_.RunWriteAndStoreVgpr(
175 dst_descs, tie(dst_bufs), dst_vgpr_desc, dst_vgpr_buf, thread_scratch_id);
176 else
177 threadwise_transfer_.RunWriteAndStoreVgpr(
178 dst_descs, tie(dst_bufs), dst_vgpr_desc, tie(dst_vgpr_buf), thread_scratch_id);
179 }
180 }
181
182 template <typename SrcBuffers, typename DstBuffers>
183 __device__ void Run(const SrcDescs& src_descs,
184 const SrcBuffers& src_bufs,
185 const DstDescs& dst_descs,
186 DstBuffers dst_bufs)
187 {
188 RunRead(src_descs, src_bufs);
189 RunWrite(dst_descs, dst_bufs);
190 }
191
192 template <index_t ISrc>
193 __device__ void
194 MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
195 {
196 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
197 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
198 {
199 threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
200 }
201 }
202
203 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
204 {
205 static_for<0, SrcDescs::Size(), 1>{}(
206 [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
207 }
208
209 template <index_t IDst>
210 __device__ void
211 MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
212 {
213 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() ||
214 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
215 {
216 threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
217 }
218 }
219
220 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
221 {
222 static_for<0, DstDescs::Size(), 1>{}(
223 [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
224 }
225
226 private:
// Descriptor of the thread cluster, built from ThreadClusterLengths and
// ThreadClusterArrangeOrder. The constructor uses it to turn a flat thread id
// into a cluster index, and its element size is the number of threads that
// take part in the transfer (see the guard at the top of each public method).
227 static constexpr auto thread_cluster_desc_ =
228 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
229
// Per-thread transfer type. It receives the same template arguments as this
// class, except that the ThreadGroup and thread-cluster parameters are not
// forwarded and the type of thread_slice_lengths is passed in place of the
// group-level SliceLengths.
230 using ThreadwiseTransfer =
231 ThreadwiseTensorSliceTransfer_v7r3<SrcDatas,
232 DstDatas,
233 SrcDescs,
234 DstDescs,
235 ElementwiseOperation,
236 DstInMemOps,
237 decltype(thread_slice_lengths),
238 SrcDimAccessOrder,
239 DstDimAccessOrder,
240 SrcVectorDim,
241 DstVectorDim,
242 SrcScalarPerVectors,
243 DstScalarPerVector,
244 ThreadTransferSrcResetCoordinateAfterRunFlags,
245 ThreadTransferDstResetCoordinateAfterRunFlags,
246 NumThreadScratch,
247 InterDatas>;
248
// The per-thread transfer object every public method delegates to.
249 ThreadwiseTransfer threadwise_transfer_;
250};
251
252} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition tuple_helper.hpp:176
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs)
Definition thread_group_tensor_slice_transfer_v7r3.hpp:183
static constexpr index_t nSrc
Definition thread_group_tensor_slice_transfer_v7r3.hpp:52
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3.hpp:211
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v7r3.hpp:57
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v7r3.hpp:55
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3.hpp:220
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r3(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_block_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_block_slice_origins, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v7r3.hpp:59
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3.hpp:194
__device__ void RunWriteAndStoreVgpr(const DstDescs &dst_descs, DstBuffers dst_bufs, const DstVgprDescs &dst_vgpr_desc, DstVgprBuffers dst_vgpr_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v7r3.hpp:157
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition thread_group_tensor_slice_transfer_v7r3.hpp:135
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v7r3.hpp:138
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v7r3.hpp:49
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v7r3.hpp:123
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, const Index &step)
Definition thread_group_tensor_slice_transfer_v7r3.hpp:203
static constexpr index_t nDst
Definition thread_group_tensor_slice_transfer_v7r3.hpp:53
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:150
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:501
__device__ void RunWriteAndStoreVgpr(const DstDescs &dst_descs, DstBuffers dst_bufs, const DstVgprDescs &, DstVgprBuffers dst_vgpr_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:408
Definition functional2.hpp:33