device_grouped_conv_fwd_multiple_d_multiple_r.hpp Source File

device_grouped_conv_fwd_multiple_d_multiple_r.hpp Source File#

Composable Kernel: device_grouped_conv_fwd_multiple_d_multiple_r.hpp Source File
device_grouped_conv_fwd_multiple_d_multiple_r.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <vector>
7
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
14// Grouped Convolution Forward:
15// input : input image A[G, N, C, Hi, Wi],
16// input : weight B[G, K, C, Y, X],
17// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
18// output : output image E[G, N, K, Ho, Wo]
19// output : R0[G, N, Ho, Wo], R1[G, N, Ho, Wo], ...
20// C = a_op(A) * b_op(B)
21// E = cde_op(C, D0, D1, ...)
22// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op1(E)), ...
23// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
24// Assume:
25// D0, D1, ... and E have the same layout
26template <index_t NDimSpatial, // number of spatial dimensions; sizes the length/stride/pad arrays below
27 typename ALayout,
28 typename BLayout,
29 typename DELayout,
30 typename RLayout,
31 typename ADataType,
32 typename BDataType,
33 typename DsDataType, // must provide a static Size(); it is read by NumDTensor below
34 typename EDataType,
35 typename RsDataType, // must provide a static Size(); it is read by NumRTensor below
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CDEElementwiseOperation,
39 typename QsElementwiseOperation,
40 typename RsElementwiseOperation> // NOTE(review): listing line 41 (the struct declaration head) is missing from this copy; numbering jumps 40 -> 42
42{
43 static constexpr index_t NumDTensor = DsDataType::Size(); // count of D tensors; fixes the size of the p_ds / ds_* arrays
44 static constexpr index_t NumRTensor = RsDataType::Size(); // count of R tensors; fixes the size of the p_rs array
45
46 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( // pure virtual; returns a std::unique_ptr<BaseArgument>
47 const void* p_a, // untyped pointer to A (read-only)
48 const void* p_b, // untyped pointer to B (read-only)
49 const std::array<const void*, NumDTensor>& p_ds, // one read-only pointer per D tensor
50 void* p_e, // untyped, writable pointer to E
51 std::array<void*, NumRTensor> p_rs, // one writable pointer per R tensor; note: taken by value, unlike p_ds
52 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // NDimSpatial + 3 entries (header comment describes A as [G, N, C, Hi, Wi])
53 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
54 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // NDimSpatial + 3 entries (header comment describes B as [G, K, C, Y, X])
55 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
56 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths, // one lengths array per D tensor
57 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides, // one strides array per D tensor
58 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // NDimSpatial + 3 entries (header comment describes E as [G, N, K, Ho, Wo])
59 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
60 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_lengths, // NDimSpatial + 2 entries; a single array, not one per R tensor
61 const std::array<index_t, NDimSpatial + 2>& r_g_n_wos_strides, // NDimSpatial + 2 entries; a single array, not one per R tensor
62 const std::array<index_t, NDimSpatial>& conv_filter_strides, // one entry per spatial dimension
63 const std::array<index_t, NDimSpatial>& conv_filter_dilations, // one entry per spatial dimension
64 const std::array<index_t, NDimSpatial>& input_left_pads, // one entry per spatial dimension
65 const std::array<index_t, NDimSpatial>& input_right_pads, // one entry per spatial dimension
66 const AElementwiseOperation& a_element_op,
67 const BElementwiseOperation& b_element_op,
68 const CDEElementwiseOperation& cde_element_op,
69 const QsElementwiseOperation& qs_element_op,
70 const RsElementwiseOperation& rs_element_op) = 0;
71
72 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; // pure virtual; returns a std::unique_ptr<BaseInvoker>
73};
74
75} // namespace device
76} // namespace tensor_operation
77} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_grouped_conv_fwd_multiple_d_multiple_r.hpp:42
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumRTensor
Definition device_grouped_conv_fwd_multiple_d_multiple_r.hpp:44
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_lengths, const std::array< index_t, NDimSpatial+2 > &r_g_n_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op)=0
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_d_multiple_r.hpp:43