static_buffer.hpp Source File

static_buffer.hpp Source File

Composable Kernel: static_buffer.hpp Source File
static_buffer.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// static buffer for scalar
// StaticBuffer: fixed-size buffer of N elements of type T. Elements are
// addressed only with compile-time indices (Number<I>), so the loops below use
// static_for, whose lambda must receive a compile-time index (operator()
// only accepts Number<I>).
// NOTE(review): this listing omits source line 15 (the struct declaration and
// its base class) and line 18 (the `base` alias); the base is presumably a
// StaticallyIndexedArray holding N elements -- verify in the real header.
// InvalidElementUseNumericalZeroValue is not referenced in the visible body
// (see the TODO on its declaration).
11template <AddressSpaceEnum AddressSpace,
12 typename T,
13 index_t N,
14 bool InvalidElementUseNumericalZeroValue> // TODO remove this bool, no longer needed
16{
// Element type exposed to generic code.
17 using type = T;
19
// Constructs the inherited storage with base{} (brace-initialization).
20 __host__ __device__ constexpr StaticBuffer() : base{} {}
21
// Element-wise assignment from a Tuple: buffer element i receives y[i].
// The static_assert rejects, at compile time, a Tuple whose length differs
// from base::Size().
22 template <typename... Ys>
23 __host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
24 {
25 static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
26 StaticBuffer& x = *this;
27 static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
28 return x;
29 }
30
// Broadcast assignment: each of the base::Size() elements is set to y.
31 __host__ __device__ constexpr StaticBuffer& operator=(const T& y)
32 {
33 StaticBuffer& x = *this;
34 static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; });
35 return x;
36 }
37
// Compile-time traits: returns the AddressSpace template argument, and tags
// this type as a static (not dynamic) buffer.
38 __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
39
40 __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
41
42 __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
43
44 // read access
// Returns a const reference to element I; forwards to base::operator[].
45 template <index_t I>
46 __host__ __device__ constexpr const T& operator[](Number<I> i) const
47 {
48 return base::operator[](i);
49 }
50
51 // write access
// Returns a mutable reference to element I; forwards to base::operator().
52 template <index_t I>
53 __host__ __device__ constexpr T& operator()(Number<I> i)
54 {
55 return base::operator()(i);
56 }
57
// Assigns a copy of x to elements 0..N-1.
// NOTE(review): this loops to N while the operator= overloads loop to
// base::Size(); these presumably agree -- confirm against the base alias.
58 __host__ __device__ void Set(T x)
59 {
60 static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; });
61 }
62
// Sets every element to T{0}; requires T{0} to be well-formed for T.
63 __host__ __device__ void Clear() { Set(T{0}); }
64};
65
66// static buffer for vector
// StaticBufferTupleOfVector: static buffer stored as NumOfVector elements of
// vector_type<S, ScalarPerVector> (the public base on listing line 74).
// Every index `i` taken by the accessors below is an offset counted in
// scalars S, not in vectors: scalar i lives in vector i / s_per_v at lane
// i % s_per_v. The last template parameter (enable_if) restricts S to types
// for which is_scalar_type<S>::value is true.
// NOTE(review): this listing omits source line 73 (the struct name) and
// lines 76-77 (the `V` and `base` aliases). Per the cross-reference list,
// V = typename vector_type<S, ScalarPerVector>::type and
// base = StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>.
// InvalidElementUseNumericalZeroValue is not referenced in the visible body.
67template <AddressSpaceEnum AddressSpace,
68 typename S,
69 index_t NumOfVector,
70 index_t ScalarPerVector,
71 bool InvalidElementUseNumericalZeroValue, // TODO remove this bool, no longer needed,
72 typename enable_if<is_scalar_type<S>::value, bool>::type = false>
74 : public StaticallyIndexedArray<vector_type<S, ScalarPerVector>, NumOfVector>
75{
78
// Compile-time geometry: scalars per vector, number of vectors, and total
// scalars in the buffer (s_per_v * num_of_v_).
79 static constexpr auto s_per_v = Number<ScalarPerVector>{};
80 static constexpr auto num_of_v_ = Number<NumOfVector>{};
81 static constexpr auto s_per_buf = s_per_v * num_of_v_;
82
// Constructs the inherited array of vectors with base{} (brace-initialization).
83 __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
84
// Compile-time traits: returns the AddressSpace template argument, and tags
// this type as a static (not dynamic) buffer.
85 __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
86
87 __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
88
89 __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
90
// Number of scalars in the buffer (NumOfVector * ScalarPerVector), NOT the
// number of vectors.
91 __host__ __device__ static constexpr index_t Size() { return s_per_buf; };
92
93 // Get S
94 // i is offset of S
// Returns a const reference to lane i % s_per_v of vector i / s_per_v.
95 template <index_t I>
96 __host__ __device__ constexpr const S& operator[](Number<I> i) const
97 {
98 constexpr auto i_v = i / s_per_v;
99 constexpr auto i_s = i % s_per_v;
100
101 return base::operator[](i_v).template AsType<S>()[i_s];
102 }
103
104 // Set S
105 // i is offset of S
// Returns a mutable reference to lane i % s_per_v of vector i / s_per_v.
106 template <index_t I>
107 __host__ __device__ constexpr S& operator()(Number<I> i)
108 {
109 constexpr auto i_v = i / s_per_v;
110 constexpr auto i_s = i % s_per_v;
111
112 return base::operator()(i_v).template AsType<S>()(i_s);
113 }
114
115 // Get X
116 // i is offset of S, not X. i should be aligned to X
// Returns, by value, the X-sized chunk that starts at scalar offset i. X's
// width in scalars is scalar_type<remove_cvref_t<X>>::vector_size. The
// static_asserts require the vector width to be a multiple of X's width and
// i to be a multiple of X's width, so a chunk never straddles two vectors.
// NOTE(review): listing line 119 (the start of the enable_if constraint on X)
// is missing from this extraction. The first static_assert message reads
// "V must one or multiple X" -- it is missing "contain" (compare SetAsType).
117 template <typename X,
118 index_t I,
120 bool>::type = false>
121 __host__ __device__ constexpr auto GetAsType(Number<I> i) const
122 {
123 constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
124
125 static_assert(s_per_v % s_per_x == 0, "wrong! V must one or multiple X");
126 static_assert(i % s_per_x == 0, "wrong!");
127
128 constexpr auto i_v = i / s_per_v;
129 constexpr auto i_x = (i % s_per_v) / s_per_x;
130
131 return base::operator[](i_v).template AsType<X>()[i_x];
132 }
133
134 // Set X
135 // i is offset of S, not X. i should be aligned to X
// Writes x into the X-sized chunk that starts at scalar offset i, under the
// same alignment static_asserts as GetAsType.
// NOTE(review): listing line 138 (the start of the enable_if constraint on X)
// is missing from this extraction.
136 template <typename X,
137 index_t I,
139 bool>::type = false>
140 __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
141 {
142 constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
143
144 static_assert(s_per_v % s_per_x == 0, "wrong! V must contain one or multiple X");
145 static_assert(i % s_per_x == 0, "wrong!");
146
147 constexpr auto i_v = i / s_per_v;
148 constexpr auto i_x = (i % s_per_v) / s_per_x;
149
150 base::operator()(i_v).template AsType<X>()(i_x) = x;
151 }
152
153 // Get read access to vector_type V
154 // i is offset of S, not V. i should be aligned to V
// Returns a const reference to the whole vector element holding scalar i;
// i must be a multiple of s_per_v (static_assert).
155 template <index_t I>
156 __host__ __device__ constexpr const auto& GetVectorTypeReference(Number<I> i) const
157 {
158 static_assert(i % s_per_v == 0, "wrong!");
159
160 constexpr auto i_v = i / s_per_v;
161
162 return base::operator[](i_v);
163 }
164
165 // Get write access to vector_type V
166 // i is offset of S, not V. i should be aligned to V
// Returns a mutable reference to the whole vector element holding scalar i;
// i must be a multiple of s_per_v (static_assert).
167 template <index_t I>
168 __host__ __device__ constexpr auto& GetVectorTypeReference(Number<I> i)
169 {
170 static_assert(i % s_per_v == 0, "wrong!");
171
172 constexpr auto i_v = i / s_per_v;
173
174 return base::operator()(i_v);
175 }
176
// Zeroes every scalar: for each of the NumOfVector * ScalarPerVector offsets,
// calls SetAsType with X deduced as S (a one-scalar-wide chunk) and S{0}.
177 __host__ __device__ void Clear()
178 {
179 constexpr index_t NumScalars = NumOfVector * ScalarPerVector;
180
181 static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
182 }
183};
184
// Factory taking the buffer length as a Number<N> tag. N is deduced from the
// argument; AddressSpace and T cannot be deduced and must be given explicitly.
// NOTE(review): the return statement (listing line 188) is missing from this
// extraction, so the exact type constructed cannot be confirmed here --
// presumably a StaticBuffer<AddressSpace, T, N, ...>; verify in the header.
185template <AddressSpaceEnum AddressSpace, typename T, index_t N>
186__host__ __device__ constexpr auto make_static_buffer(Number<N>)
187{
189}
190
// Overload of make_static_buffer taking the length as a LongNumber<N> tag
// (N of type long_index_t). AddressSpace and T must be given explicitly.
// NOTE(review): the return statement (listing line 194) is missing from this
// extraction. StaticBuffer above takes its length as index_t, so if this
// overload returns a StaticBuffer, N presumably must fit in index_t -- confirm.
191template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
192__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
193{
195}
196
197} // namespace ck
Definition ck.hpp:268
integral_constant< long_index_t, N > LongNumber
Definition number.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
constexpr bool is_native_type()
Definition data_type.hpp:203
integral_constant< index_t, N > Number
Definition number.hpp:12
AddressSpaceEnum
Definition amd_address_space.hpp:15
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
StaticBuffer
Definition static_buffer.hpp:16
__host__ __device__ constexpr StaticBuffer()
Definition static_buffer.hpp:20
__host__ __device__ void Clear()
Definition static_buffer.hpp:63
__host__ __device__ static constexpr bool IsDynamicBuffer()
Definition static_buffer.hpp:42
__host__ __device__ constexpr StaticBuffer & operator=(const T &y)
Definition static_buffer.hpp:31
__host__ __device__ constexpr StaticBuffer & operator=(const Tuple< Ys... > &y)
Definition static_buffer.hpp:23
StaticallyIndexedArray< AccDataType, N > base
Definition static_buffer.hpp:18
__host__ __device__ constexpr T & operator()(Number< I > i)
Definition static_buffer.hpp:53
__host__ __device__ constexpr const T & operator[](Number< I > i) const
Definition static_buffer.hpp:46
__host__ __device__ static constexpr bool IsStaticBuffer()
Definition static_buffer.hpp:40
__host__ __device__ void Set(T x)
Definition static_buffer.hpp:58
__host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
Definition static_buffer.hpp:38
__host__ __device__ static constexpr bool IsDynamicBuffer()
Definition static_buffer.hpp:89
__host__ __device__ constexpr void SetAsType(Number< I > i, X x)
Definition static_buffer.hpp:140
__host__ __device__ static constexpr index_t Size()
Definition static_buffer.hpp:91
__host__ __device__ void Clear()
Definition static_buffer.hpp:177
StaticallyIndexedArray< vector_type< S, ScalarPerVector >, NumOfVector > base
Definition static_buffer.hpp:77
__host__ __device__ constexpr auto & GetVectorTypeReference(Number< I > i)
Definition static_buffer.hpp:168
__host__ __device__ constexpr const S & operator[](Number< I > i) const
Definition static_buffer.hpp:96
__host__ __device__ static constexpr bool IsStaticBuffer()
Definition static_buffer.hpp:87
__host__ __device__ constexpr auto GetAsType(Number< I > i) const
Definition static_buffer.hpp:121
__host__ __device__ constexpr S & operator()(Number< I > i)
Definition static_buffer.hpp:107
__host__ __device__ constexpr const auto & GetVectorTypeReference(Number< I > i) const
Definition static_buffer.hpp:156
typename vector_type< S, ScalarPerVector >::type V
Definition static_buffer.hpp:76
static constexpr auto s_per_buf
Definition static_buffer.hpp:81
__host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
Definition static_buffer.hpp:85
__host__ __device__ constexpr StaticBufferTupleOfVector()
Definition static_buffer.hpp:83
static constexpr auto s_per_v
Definition static_buffer.hpp:79
static constexpr auto num_of_v_
Definition static_buffer.hpp:80
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition dtype_vector.hpp:10