mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
operator_util.h
Go to the documentation of this file.
1 
11 #ifndef MXNET_OPERATOR_UTIL_H_
12 #define MXNET_OPERATOR_UTIL_H_
13 
14 #include <dmlc/registry.h>
15 #include <dmlc/parameter.h>
16 #include <map>
17 #include <vector>
18 #include <string>
19 #include <utility>
20 #include "./base.h"
21 #include "./operator.h"
22 
23 #if DMLC_USE_CXX11
24 #include <functional>
25 #endif
26 
27 namespace mxnet {
29 namespace op {
34 };
35 
40 
45 
// EnvArguments: environment arguments handed to every simple-op kernel; all
// forward, shape and gradient function types below take it as
// `const EnvArguments& env`.
// NOTE(review): this listing elides source lines 51-53 and 55; the glossary
// records a `real_t scalar` member ("scalar argument, if enabled") at
// operator_util.h:52, which is not shown here.
50 struct EnvArguments {
// keyword arguments as (key, value) string pairs
54  std::vector<std::pair<std::string, std::string> > kwargs;
// the resources requested for this op (cf. set_resource_request); presumably
// filled in by the registry before the kernel runs -- TODO confirm
56  std::vector<Resource> resource;
57 };
58 
// UnaryFunction: forward kernel of a unary simple op. Reads `src` and writes
// the result into the pre-allocated `*ret`. `req` tells the kernel how to
// store the result (kNullOp / kWriteTo / kWriteInplace / kAddTo -- see the
// ASSIGN_DISPATCH macro in this file for one way to honour it), and `ctx` is
// the execution-time run context.
68 typedef void (*UnaryFunction)(const TBlob& src,
69  const EnvArguments& env,
70  TBlob* ret,
71  OpReqType req,
72  RunContext ctx);
// UnaryShapeFunction: infers the output shape from the source shape and the
// environment arguments. It is registered via set_shape_function; when none is
// registered the default is out_shape = in_shape.
79 typedef TShape (*UnaryShapeFunction)(const TShape& src,
80  const EnvArguments& env);
81 
// UnaryGradFunctionT0: backward kernel of a unary op that needs only the
// gradient of the output. Writes d(loss)/d(input) into `*in_grad` according
// to `req`.
90 typedef void (*UnaryGradFunctionT0)(const OutputGrad& out_grad,
91  const EnvArguments& env,
92  TBlob* in_grad,
93  OpReqType req,
94  RunContext ctx);
// UnaryGradFunctionT1: backward kernel that additionally receives the forward
// output value (`out_value`), for ops whose derivative is cheapest to express
// in terms of their own output.
104 typedef void (*UnaryGradFunctionT1)(const OutputGrad& out_grad,
105  const OutputValue& out_value,
106  const EnvArguments& env,
107  TBlob* in_grad,
108  OpReqType req,
109  RunContext ctx);
// UnaryGradFunctionT2: backward kernel that additionally receives the forward
// input (`in_data0`) instead of the output value.
119 typedef void (*UnaryGradFunctionT2)(const OutputGrad& out_grad,
120  const Input0& in_data0,
121  const EnvArguments& env,
122  TBlob* in_grad,
123  OpReqType req,
124  RunContext ctx);
// BinaryFunction: forward kernel of a binary simple op. Combines `lhs` and
// `rhs` and writes the result into the pre-allocated `*ret`, honouring `req`.
135 typedef void (*BinaryFunction)(const TBlob& lhs,
136  const TBlob& rhs,
137  const EnvArguments& env,
138  TBlob* ret,
139  OpReqType req,
140  RunContext ctx);
141 
// BinaryShapeFunction: infers the output shape from both operand shapes and
// the environment arguments (registered via set_shape_function).
149 typedef TShape (*BinaryShapeFunction)(const TShape& lhs,
150  const TShape& rhs,
151  const EnvArguments& env);
// BinaryGradFunctionT0: backward kernel of a binary op that needs only the
// output gradient. Both input gradients are produced in one call: `*lhs_grad`
// is written according to `req_lhs_grad` and `*rhs_grad` according to
// `req_rhs_grad`.
163 typedef void (*BinaryGradFunctionT0)(const OutputGrad& out_grad,
164  const EnvArguments& env,
165  TBlob* lhs_grad,
166  TBlob* rhs_grad,
167  OpReqType req_lhs_grad,
168  OpReqType req_rhs_grad,
169  RunContext ctx);
// BinaryGradFunctionT1: like BinaryGradFunctionT0, but also receives both
// forward inputs (`lhs`, `rhs`).
182 typedef void (*BinaryGradFunctionT1)(const OutputGrad& out_grad,
183  const Input0& lhs,
184  const Input1& rhs,
185  const EnvArguments& env,
186  TBlob* lhs_grad,
187  TBlob* rhs_grad,
188  OpReqType req_lhs_grad,
189  OpReqType req_rhs_grad,
190  RunContext ctx);
191 
204 };
205 
210 };
211 
216 };
217 
220  public:
// name of the operator this entry registers
224  std::string name;
// Set a separate name for the symbolic operator. Must be called before
// set_function. NOTE(review): the upstream brief about the default is
// truncated here; presumably it defaults to the same name as the function.
// Like every setter below, returns TSelf& (this entry) so calls can chain.
231  virtual TSelf& set_symbol_op_name(char const* symbol_name) = 0;
// Enable scalar arguments passed through `env`. `type_mask` (default
// kArrayBeforeScalar) selects the scalar-argument option.
// A function cannot have both kwargs and scalar arguments.
239  virtual TSelf& set_enable_scalar(
240  bool enable_scalar,
241  SimpleOpScalarOption type_mask = kArrayBeforeScalar) = 0;
// Enable keyword arguments (delivered in EnvArguments::kwargs). A function
// cannot have both kwargs and scalar arguments. Default: false.
248  virtual TSelf& set_enable_kwargs(bool enable_kwargs) = 0;
// Request extra resources for the op (several at once, or a single one).
// By default no resource is requested; granted resources are exposed to the
// kernels through EnvArguments::resource.
255  virtual TSelf& set_resource_request(
256  const std::vector<ResourceRequest>& reqs) = 0;
263  virtual TSelf& set_resource_request(ResourceRequest req) = 0;
// Set the shape inference function for the unary / binary form of the op.
// Default when unset: out_shape = in_shape.
269  virtual TSelf& set_shape_function(UnaryShapeFunction fshapeinfer) = 0;
275  virtual TSelf& set_shape_function(BinaryShapeFunction fshapeinfer) = 0;
// Set the forward kernel for the device selected by `dev_mask` (an int mask;
// presumably the cpu/gpu device mask -- confirm against callers).
// `inplace_in_out` / `inplace_lhs_out` say whether the (left) input may share
// storage with the output; `register_symbolic` (default kRegisterSymbolic)
// controls whether a symbolic operator is registered as well.
283  virtual TSelf& set_function(
284  int dev_mask,
285  UnaryFunction funary,
286  SimpleOpInplaceOption inplace_in_out,
287  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
295  virtual TSelf& set_function(
296  int dev_mask,
297  BinaryFunction fbinary,
298  SimpleOpInplaceOption inplace_lhs_out,
299  SimpleOpRegOption register_symbolic = kRegisterSymbolic) = 0;
// Set the gradient kernel for the device selected by `dev_mask`. The overload
// is picked by the gradient function type:
//   UnaryGradFunctionT0/T1/T2 -- unary ops; `inplace_out_in_grad` says whether
//     out_grad may share storage with in_grad.
//   BinaryGradFunctionT0/T1   -- binary ops; `inplace_out_lhs_grad` says
//     whether out_grad may share storage with lhs_grad.
306  virtual TSelf& set_gradient(int dev_mask,
307  UnaryGradFunctionT0 fgrad,
308  SimpleOpInplaceOption inplace_out_in_grad) = 0;
315  virtual TSelf& set_gradient(int dev_mask,
316  UnaryGradFunctionT1 fgrad,
317  SimpleOpInplaceOption inplace_out_in_grad) = 0;
324  virtual TSelf& set_gradient(int dev_mask,
325  UnaryGradFunctionT2 fgrad,
326  SimpleOpInplaceOption inplace_out_in_grad) = 0;
333  virtual TSelf& set_gradient(int dev_mask,
334  BinaryGradFunctionT0 fgrad,
335  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
342  virtual TSelf& set_gradient(int dev_mask,
343  BinaryGradFunctionT1 fgrad,
344  SimpleOpInplaceOption inplace_out_lhs_grad) = 0;
// Attach a human-readable description to the function.
350  virtual TSelf& describe(const std::string &description) = 0;
// Add argument metadata (dmlc::ParamFieldInfo) describing the function's
// arguments.
357  virtual TSelf& add_arguments(const std::vector<dmlc::ParamFieldInfo> &args) = 0;
// virtual destructor: entries are used polymorphically via SimpleOpRegEntry
359  virtual ~SimpleOpRegEntry() {}
360 };
361 
364  public:
// Internal: return the entry registered under `name`, creating it if it does
// not exist yet. This is what MXNET_REGISTER_SIMPLE_OP expands to.
370  SimpleOpRegEntry &__REGISTER_OR_FIND__(char const* name);
376  inline static const SimpleOpRegEntry *Find(const std::string &name) {
377  return Get()->fmap_.at(name);
378  }
// Return the registry instance (defined out of line; presumably a
// process-wide singleton, given the private destructor -- confirm).
380  static SimpleOpRegistry* Get();
381 
382  private:
383  // destructor
384  ~SimpleOpRegistry();
// operator name -> registered entry. Entries are held by raw pointer.
// NOTE(review): release is presumably done by the out-of-line destructor.
386  std::map<std::string, SimpleOpRegEntry*> fmap_;
387 };
388 
// ASSIGN_DISPATCH(out, req, exp): store the expression `exp` into `out`
// according to the request type `req`:
//   kNullOp                  -> nothing is written,
//   kWriteTo / kWriteInplace -> out = exp,
//   kAddTo                   -> out += exp,
// and any other value aborts via LOG(FATAL). `req` is evaluated once and
// `exp` at most once.
#define ASSIGN_DISPATCH(out, req, exp)      \
  {                                         \
    switch (req) {                          \
      case kAddTo: {                        \
        (out) += (exp);                     \
        break;                              \
      }                                     \
      case kWriteTo:                        \
      case kWriteInplace: {                 \
        (out) = (exp);                      \
        break;                              \
      }                                     \
      case kNullOp: {                       \
        break;                              \
      }                                     \
      default: {                            \
        LOG(FATAL) << "not reached";        \
      }                                     \
    }                                       \
  }
413 
// MXNET_RANGE_SWITCH(var, NDIM, ...): dispatch on the runtime value `var`
// (1..5) and run the body `__VA_ARGS__` with `NDIM` declared as an integer
// compile-time constant equal to that value, so the body can instantiate
// templates on it. Any other value triggers LOG(FATAL).
#define MXNET_RANGE_SWITCH(var, NDIM, ...)           \
  {                                                   \
    switch (var) {                                    \
      case 1: {                                       \
        static const int NDIM = 1;                    \
        {__VA_ARGS__}                                 \
        break;                                        \
      }                                               \
      case 2: {                                       \
        static const int NDIM = 2;                    \
        {__VA_ARGS__}                                 \
        break;                                        \
      }                                               \
      case 3: {                                       \
        static const int NDIM = 3;                    \
        {__VA_ARGS__}                                 \
        break;                                        \
      }                                               \
      case 4: {                                       \
        static const int NDIM = 4;                    \
        {__VA_ARGS__}                                 \
        break;                                        \
      }                                               \
      case 5: {                                       \
        static const int NDIM = 5;                    \
        {__VA_ARGS__}                                 \
        break;                                        \
      }                                               \
      default:                                        \
        LOG(FATAL) << "Only support ndim=1 to 5.";    \
    }                                                 \
  }
456 
457 
458 //--------------------------------------------------------------
459 // The following part are API Registration of Simple Operators
460 //--------------------------------------------------------------
// MXNET_REGISTER_SIMPLE_OP(Name, DEV): defines a file-static reference named
// __make_SimpleOpRegEntry_<Name>__<DEV>__ bound to the entry returned by
// SimpleOpRegistry::Get()->__REGISTER_OR_FIND__("<Name>").
// DEV is pasted into the variable name, so registering the same Name with
// different DEV values yields distinct variables. The expansion ends with the
// call expression (no semicolon), so callers can chain the TSelf& setters,
// e.g. MXNET_REGISTER_SIMPLE_OP(foo, XPU).set_function(...);  (illustrative)
478 #define MXNET_REGISTER_SIMPLE_OP(Name, DEV) \
479  static ::mxnet::op::SimpleOpRegEntry & \
480  __make_ ## SimpleOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \
481  ::mxnet::op::SimpleOpRegistry::Get()->__REGISTER_OR_FIND__(#Name)
482 
483 } // namespace op
484 } // namespace mxnet
485 #endif // MXNET_OPERATOR_UTIL_H_
Definition: operator_util.h:215
Gradient of output value.
Definition: operator_util.h:44
virtual TSelf & set_resource_request(const std::vector< ResourceRequest > &reqs)=0
set resource request. By default there is no resource request. The resource will be presented in both ...
SimpleOpRegOption
options in the registry to set symbolic registration
Definition: operator_util.h:213
mshadow::TShape TShape
dynamic shape type
Definition: base.h:85
Definition: operator_util.h:209
TShape(* BinaryShapeFunction)(const TShape &lhs, const TShape &rhs, const EnvArguments &env)
Shape inference function to get the correct shape given source shapes.
Definition: operator_util.h:149
virtual TSelf & set_enable_kwargs(bool enable_kwargs)=0
set whether to enable kwargs. A function cannot have both kwargs and scalar arguments. Default: this is set to false.
virtual TSelf & set_symbol_op_name(char const *symbol_name)=0
set a separate name for the symbol. This must be called before set_function. Default: this is set to be sa...
virtual TSelf & add_arguments(const std::vector< dmlc::ParamFieldInfo > &args)=0
Describe the function.
void(* UnaryGradFunctionT1)(const OutputGrad &out_grad, const OutputValue &out_value, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the output value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:104
registry for TBlob functions
Definition: operator_util.h:363
The resources that can be requested by Operator.
Definition: resource.h:18
mshadow::default_real_t real_t
data type that will be used to store ndarray
Definition: base.h:82
mshadow::TBlob TBlob
storage container type
Definition: base.h:87
std::string name
name of the operator
Definition: operator_util.h:224
in unary forward, allow inplace in with out
Definition: operator_util.h:197
void(* UnaryGradFunctionT2)(const OutputGrad &out_grad, const Input0 &in_data0, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes the input value of the function and computes the gradient with respect to the input...
Definition: operator_util.h:119
execution time context. The information needed at runtime for actual execution.
Definition: base.h:181
real_t scalar
scalar argument, if enabled
Definition: operator_util.h:52
static const SimpleOpRegEntry * Find(const std::string &name)
Find the entry with corresponding name.
Definition: operator_util.h:376
virtual ~SimpleOpRegEntry()
virtual destructor
Definition: operator_util.h:359
registry entry to register simple operators via functions.
Definition: operator_util.h:219
Operator interface of mxnet.
virtual TSelf & describe(const std::string &description)=0
Describe the function.
Second input to the function.
Definition: operator_util.h:39
TShape(* UnaryShapeFunction)(const TShape &src, const EnvArguments &env)
Shape inference function to get the correct shape given source.
Definition: operator_util.h:79
in unary backward, allow inplace out_grad with in_grad
Definition: operator_util.h:199
super class of all gradient function arguments
Definition: operator_util.h:31
TBlob data
The real data.
Definition: operator_util.h:33
do not allow inplace in arguments
Definition: operator_util.h:195
void(* UnaryFunction)(const TBlob &src, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Unary function that takes a src and saves the result to ret. The result container is pre-allocated with th...
Definition: operator_util.h:68
SimpleOpInplaceOption
options in the registry to set inplace of operator
Definition: operator_util.h:193
void(* UnaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *in_grad, OpReqType req, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradient with respect to the input...
Definition: operator_util.h:90
virtual TSelf & set_enable_scalar(bool enable_scalar, SimpleOpScalarOption type_mask=kArrayBeforeScalar)=0
set number of scalar arguments needed to be passed in env. A function cannot have both kwargs and scal...
void(* BinaryGradFunctionT0)(const OutputGrad &out_grad, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes only the output gradient and computes the gradients with respect to both inputs. Both input gradients are computed in one call, which makes it easy to combine a few ops.
Definition: operator_util.h:163
static SimpleOpRegistry * Get()
OpReqType
operation request type to Forward and Backward
Definition: operator.h:23
Environment arguments that are used by the function. These can be things like scalar arguments when ad...
Definition: operator_util.h:50
Output value of the function, passed to the gradient function.
Definition: operator_util.h:42
SimpleOpScalarOption
options in the registry to configure scalar arguments (e.g. kArrayBeforeScalar)
Definition: operator_util.h:207
configuration of mxnet as well as basic data structures.
virtual TSelf & set_shape_function(UnaryShapeFunction fshapeinfer)=0
set shape inference function. Default: out_shape = in_shape
void(* BinaryFunction)(const TBlob &lhs, const TBlob &rhs, const EnvArguments &env, TBlob *ret, OpReqType req, RunContext ctx)
Binary function that takes lhs, rhs and saves the result to ret. The result container is pre-allocated wit...
Definition: operator_util.h:135
void(* BinaryGradFunctionT1)(const OutputGrad &out_grad, const Input0 &lhs, const Input1 &rhs, const EnvArguments &env, TBlob *lhs_grad, TBlob *rhs_grad, OpReqType req_lhs_grad, OpReqType req_rhs_grad, RunContext ctx)
Gradient function that takes the inputs of the function and computes the gradients with respect to the inputs.
Definition: operator_util.h:182
First input to the function.
Definition: operator_util.h:37
std::vector< std::pair< std::string, std::string > > kwargs
keyword arguments
Definition: operator_util.h:54
Definition: operator_util.h:208
Definition: operator_util.h:214
in binary forward, allow inplace left operand with out
Definition: operator_util.h:201
SimpleOpRegEntry & __REGISTER_OR_FIND__(char const *name)
Internal function to register a function under the given name, or find the existing entry.
in binary backward, allow inplace out_grad with lhs_grad
Definition: operator_util.h:203
virtual TSelf & set_gradient(int dev_mask, UnaryGradFunctionT0 fgrad, SimpleOpInplaceOption inplace_out_in_grad)=0
set the gradient function of this function.
std::vector< Resource > resource
pointer to the resources requested
Definition: operator_util.h:56
SimpleOpRegEntry TSelf
declare self type
Definition: operator_util.h:222
virtual TSelf & set_function(int dev_mask, UnaryFunction funary, SimpleOpInplaceOption inplace_in_out, SimpleOpRegOption register_symbolic=kRegisterSymbolic)=0
set the forward function of this function to be funary