thread_welford.hpp Source File

thread_welford.hpp Source File

Composable Kernel: thread_welford.hpp Source File
thread_welford.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10template <typename T, bool kFastFDiv = false>
11CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count, bool_constant<kFastFDiv> = {})
12{
13 // TODO: check nan? maybe no
14 T delta = x - mean;
15 if(kFastFDiv && std::is_same_v<T, float>)
16 {
17 mean += delta * __builtin_amdgcn_rcpf(count);
18 }
19 else
20 {
21 mean += delta / count;
22 }
23 T delta2 = x - mean;
24 var += delta * delta2;
25}
26
27template <typename T, bool kFastFDiv = false>
28CK_TILE_DEVICE static void welford_merge(T& mean_a,
29 T& var_a,
30 int& count_a,
31 T mean_b,
32 T var_b,
33 int count_b,
35{
36 int count = count_a + count_b;
37 T count_ = type_convert<T>(count);
38 T count_a_ = type_convert<T>(count_a);
39 T count_b_ = type_convert<T>(count_b);
40 T count_b_over_count;
41 if(kFastFDiv && std::is_same_v<T, float>)
42 {
43 count_b_over_count =
44 count == 0 ? type_convert<T>(0) : count_b_ * __builtin_amdgcn_rcpf(count_);
45 }
46 else
47 {
48 count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
49 }
50
51 T delta = mean_b - mean_a;
52 mean_a += delta * count_b_over_count;
53 var_a += var_b + delta * delta * count_a_ * count_b_over_count;
54 count_a = count;
55}
56
57} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void welford_update(T &mean, T &var, T x, int count, bool_constant< kFastFDiv >={})
Definition thread_welford.hpp:11
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29