slice_tile.hpp Source File

slice_tile.hpp Source File

Composable Kernel: slice_tile.hpp Source File
slice_tile.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
15
16namespace ck_tile {
17
18template <typename BottomTensorView_,
19 typename WindowLengths_,
20 index_t... SliceBegins,
21 index_t... SliceEnds>
22CK_TILE_DEVICE constexpr auto
24 sequence<SliceBegins...> slice_begins,
25 sequence<SliceEnds...> slice_ends)
26{
28 // NOTE: This API will override the origin of the tile window!
29 static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
30 static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
31
32 constexpr auto slice_lengths = slice_ends - slice_begins;
33
35 sequence_to_tuple_of_number(slice_lengths),
36 to_multi_index(slice_begins));
37}
38
39template <typename DataType_,
40 typename StaticTileDistribution_,
41 index_t... SliceBegins,
42 index_t... SliceEnds>
43CK_TILE_DEVICE constexpr auto
45 sequence<SliceBegins...> slice_begins,
46 sequence<SliceEnds...> slice_ends)
47{
48 using DataType = remove_cvref_t<DataType_>;
49 using Distribution = remove_cvref_t<StaticTileDistribution_>;
50
51 constexpr auto sliced_dstr_yidx_ylen =
52 detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
53
54 constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
55 constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
56 constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
57
58 auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
59
60 sliced_tensor.get_thread_buffer() =
61 tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);
62
63 return sliced_tensor;
64}
65
66template <typename DstDataType_,
67 typename DstStaticTileDistribution_,
68 typename SrcDataType_,
69 typename SrcStaticTileDistribution_,
70 index_t... SliceBegins,
71 index_t... SliceEnds>
72CK_TILE_DEVICE constexpr auto
75 sequence<SliceBegins...> slice_begins,
76 sequence<SliceEnds...> slice_ends)
77{
78 using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
79 using SrcDistribution = remove_cvref_t<SrcStaticTileDistribution_>;
80
81 constexpr auto sliced_dstr_yidx_ylen =
82 detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
83
84 constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
85 constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
86 constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
87
88 static_assert(std::is_same_v<remove_cvref_t<decltype(sliced_dstr)>, SrcDistribution>, "wrong!");
89
91 sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
92}
93
94} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition tile_distribution.hpp:554
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T &x)
Definition tile/core/container/multi_index.hpp:33
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence< Is... >)
Definition tile/core/container/container_helper.hpp:459
CK_TILE_DEVICE constexpr auto set_slice_tile(static_distributed_tensor< DstDataType_, DstStaticTileDistribution_ > &dst_tile, const static_distributed_tensor< SrcDataType_, SrcStaticTileDistribution_ > &src_tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:73
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
int32_t index_t
Definition integer.hpp:9
Definition tile/core/container/sequence.hpp:49
Definition static_distributed_tensor.hpp:21
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >, const SlicedThreadData &sliced_thread_data)
Definition static_distributed_tensor.hpp:93
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence< YSliceOrigins... >, sequence< YSliceLengths... >) const
Definition static_distributed_tensor.hpp:68
CK_TILE_HOST_DEVICE constexpr const auto & get_thread_buffer() const
Definition static_distributed_tensor.hpp:58
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const
Definition tile_window_base.hpp:47
This class provides description of tile windowed view on the device memory.
Definition tile_window.hpp:1016