DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > Struct Template Reference

DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > Struct Template Reference

Composable Kernel: ck::tensor_operation::device::DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > Struct Template Reference
ck::tensor_operation::device::DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > Struct Template Reference (abstract)

#include <device_normalization_fwd.hpp>

Inheritance diagram for ck::tensor_operation::device::DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim >:
Base class:
  ck::tensor_operation::device::BaseOperator
Derived classes:
  ck::tensor_operation::device::DeviceNormalizationFwdImpl< XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdDstVectorSize, UseWelford >
  ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl< XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdDstVectorSize >

Public Member Functions

virtual std::unique_ptr< BaseArgument > MakeArgumentPointer (const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, double epsilon, const void *p_x, const void *p_gamma, const void *p_beta, void *p_y, void *p_savedMean, void *p_savedInvVar, YElementwiseOperation y_elementwise_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer ()=0
Public Member Functions inherited from ck::tensor_operation::device::BaseOperator
 BaseOperator ()=default
 BaseOperator (const BaseOperator &)=default
BaseOperator & operator= (const BaseOperator &)=default
virtual bool IsSupportedArgument (const BaseArgument *)
virtual std::string GetTypeString () const
virtual std::string GetInstanceString () const
virtual std::string GetTypeIdName () const
virtual std::optional< std::string > GetObjectName () const
virtual std::optional< std::string > GetTemplateInfo () const
virtual std::string GetTypeIdHashCode () const
virtual size_t GetWorkSpaceSize (const BaseArgument *) const
virtual void SetWorkSpacePointer (BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
virtual ~BaseOperator ()

Member Function Documentation

◆ MakeArgumentPointer()

template<typename XDataType, typename GammaDataType, typename BetaDataType, typename YDataType, typename SaveMeanInvStdDataType, typename YElementwiseOperation, index_t Rank, index_t NumReduceDim>
virtual std::unique_ptr< BaseArgument > ck::tensor_operation::device::DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim >::MakeArgumentPointer ( const std::vector< index_t > lengths,
const std::vector< index_t > xStrides,
const std::vector< index_t > gammaStrides,
const std::vector< index_t > betaStrides,
const std::vector< index_t > yStrides,
const std::vector< index_t > saveMeanStrides,
const std::vector< index_t > saveInvStdStrides,
const std::vector< index_t > reduceDims,
double epsilon,
const void * p_x,
const void * p_gamma,
const void * p_beta,
void * p_y,
void * p_savedMean,
void * p_savedInvVar,
YElementwiseOperation y_elementwise_op )
pure virtual

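◆ MakeInvokerPointer()

template<typename XDataType, typename GammaDataType, typename BetaDataType, typename YDataType, typename SaveMeanInvStdDataType, typename YElementwiseOperation, index_t Rank, index_t NumReduceDim>
virtual std::unique_ptr< BaseInvoker > ck::tensor_operation::device::DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim >::MakeInvokerPointer ( )
pure virtual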

The documentation for this struct was generated from the following file:

device_normalization_fwd.hpp