11 #ifndef MXNET_OPERATOR_UTIL_H_
12 #define MXNET_OPERATOR_UTIL_H_
14 #include <dmlc/registry.h>
15 #include <dmlc/parameter.h>
54 std::vector<std::pair<std::string, std::string> >
kwargs;
256 const std::vector<ResourceRequest>& reqs) = 0;
377 return Get()->fmap_.at(name);
386 std::map<std::string, SimpleOpRegEntry*> fmap_;
397 #define ASSIGN_DISPATCH(out, req, exp) \
403 case kWriteInplace: \
410 LOG(FATAL) << "not reached"; \
419 #define MXNET_RANGE_SWITCH(var, NDIM, ...) \
424 static const int NDIM = 1; \
430 static const int NDIM = 2; \
436 static const int NDIM = 3; \
442 static const int NDIM = 4; \
448 static const int NDIM = 5; \
453 LOG(FATAL) << "Only support ndim=1 to 5."; \
478 #define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \
479 static ::mxnet::op::SimpleOpRegEntry & \
480 __make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \
481 ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
485 #endif // MXNET_OPERATOR_UTIL_H_
Definition: operator_util.h:215
Gradient of output value.
Definition: operator_util.h:44
virtual TSelf & set_resource_request(const std::vector< ResourceRequest > &reqs)=0
Set the resource request. By default there is no resource request. The resource will be presented in both ...
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:213
mshadow::TShape TShape
dynamic shape type
Definition: base.h:85
Definition: operator_util.h:209
TShape(* BinaryShapeFunction)(const TShape &lhs, const TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:149
virtual TSelf & set_enable_kwargs(bool enable_kwargs)=0
Set whether to enable kwargs. A function cannot have both kwargs and scalar arguments. Default: false.
virtual TSelf & set_symbol_op_name(char const *symbol_name)=0
Set a separate name for the symbol. This must be called before set_function. Default: this is set to be sa...
virtual TSelf & add_arguments(const std::vector< dmlc::ParamFieldInfo > &args)=0
Add argument descriptions to the function.
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the output value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:104
registry for TBlob functions
Definition: operator_util.h:363
The resources that can be requested by Operator.
Definition: resource.h:18
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:82
mshadow::TBlob TBlob
storage container type
Definition: base.h:87
std::string name
name of the operator
Definition: operator_util.h:224
in unary forward, allow inplace in with out
Definition: operator_util.h:197
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:119
execution time context. The information needed in runtime for actual execution.
Definition: base.h:181
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:52
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:376
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:359
registry entry to register simple operators via functions.
Definition: operator_util.h:219
Operator interface of mxnet.
virtual TSelf & describe(const std::string &description)=0
Describe the function.
TShape(* UnaryShapeFunction)(const TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:79
in unary backward, allow inplace out_grad with in_grad
Definition: operator_util.h:199
super class of all gradient function argument
Definition: operator_util.h:31
TBlob data
The real data.
Definition: operator_util.h:33
do not allow inplace in arguments
Definition: operator_util.h:195
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes a src and saves the result to ret. The result container is pre-allocated with th...
Definition: operator_util.h:68
SimpleOpInplaceOption
options in the registry to set inplace of operator
Definition: operator_util.h:193
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the input...
Definition: operator_util.h:90
virtual TSelf & set_enable_scalar(bool enable_scalar, SimpleOpScalarOption type_mask=kArrayBeforeScalar)=0
Set the number of scalar arguments that need to be passed in env. A function cannot have both kwargs and scal...
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the inputs. The total gradient is computed as a whole to make it easy to combine a few ops.
Definition: operator_util.h:163
static SimpleOpRegistry * Get()
OpReqType
operation request type to Forward and Backward
Definition: operator.h:23
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:50
Output value of the function, passed to the gradient function.
Definition: operator_util.h:42
SimpleOpScalarOption
options in the registry controlling how scalar arguments are placed relative to array arguments
Definition: operator_util.h:207
configuration of mxnet as well as basic data structures.
virtual TSelf & set_shape_function(UnaryShapeFunction fshapeinfer)=0
set shape inference function. Default: out_shape = in_shape
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs and rhs and saves the result to ret. The result container is pre-allocated wit...
Definition: operator_util.h:135
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradient with respect to the inputs.
Definition: operator_util.h:182
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:54
Definition: operator_util.h:208
Definition: operator_util.h:214
in binary forward, allow inplace left operand with out
Definition: operator_util.h:201
SimpleOpRegEntry & __REGISTER_OR_FIND__(char const *name)
Internal function to register a function under the given name, or find the existing entry for it.
in binary backward, allow inplace out_grad with lhs_grad
Definition: operator_util.h:203
virtual TSelf & set_gradient(int dev_mask, UnaryGradFunctionT0 fgrad, SimpleOpInplaceOption inplace_out_in_grad)=0
Set the gradient function of this function.
std::vector< Resource > resource
the resources requested
Definition: operator_util.h:56
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:222
virtual TSelf & set_function(int dev_mask, UnaryFunction funary, SimpleOpInplaceOption inplace_in_out, SimpleOpRegOption register_symbolic=kRegisterSymbolic)=0
Set the function to be funary.