7 #ifndef MXNET_OPERATOR_H_
8 #define MXNET_OPERATOR_H_
10 #include <dmlc/base.h>
11 #include <dmlc/json.h>
12 #include <dmlc/logging.h>
13 #include <dmlc/registry.h>
58 template<
typename xpu>
109 const std::vector<TBlob> &in_data,
110 const std::vector<OpReqType> &req,
111 const std::vector<TBlob> &out_data,
112 const std::vector<TBlob> &aux_states) = 0;
142 const std::vector<TBlob> &out_grad,
143 const std::vector<TBlob> &in_data,
144 const std::vector<TBlob> &out_data,
145 const std::vector<OpReqType> &req,
146 const std::vector<TBlob> &in_grad,
147 const std::vector<TBlob> &aux_states) {
148 LOG(FATAL) <<
"Backward is not implemented";
176 virtual void Init(
const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
181 virtual std::map<std::string, std::string>
GetParams()
const = 0;
239 virtual bool InferShape(std::vector<TShape> *in_shape,
240 std::vector<TShape> *out_shape,
241 std::vector<TShape> *aux_shape)
const = 0;
260 std::vector<int> *out_type,
261 std::vector<int> *aux_type)
const {
264 for (
unsigned i = 0; i < in_type->size(); ++i) {
265 CHECK(in_type->at(i) == mshadow::default_type_flag ||
266 in_type->at(i) == -1) <<
"Unsupported data type " << in_type->at(i);
269 for (
int i = 0; i < n_in; ++i ) in_type->push_back(mshadow::default_type_flag);
273 for (
int i = 0; i < n_out; ++i ) out_type->push_back(mshadow::default_type_flag);
277 for (
int i = 0; i < n_aux; ++i ) aux_type->push_back(mshadow::default_type_flag);
297 std::vector<int> *in_type)
const {
298 std::vector<int> out_type, aux_type;
299 std::vector<TShape> out_shape, aux_shape;
304 CHECK(
InferType(in_type, &out_type, &aux_type));
305 CHECK(
InferShape(in_shape, &out_shape, &aux_shape));
325 const std::vector<TShape> &in_shape)
const {
326 return std::vector<ResourceRequest>();
336 const std::vector<TShape> &in_shape)
const {
337 return std::vector<ResourceRequest>();
362 const std::vector<int> &out_grad,
363 const std::vector<int> &in_data,
364 const std::vector<int> &out_data)
const {
367 std::vector<int> ret = out_grad;
368 ret.insert(ret.end(), in_data.begin(), in_data.end());
369 ret.insert(ret.end(), out_data.begin(), out_data.end());
394 const std::vector<int> &in_data,
395 const std::vector<void*> &out_data)
const {
396 return std::vector<std::pair<int, void*> >();
425 const std::vector<int> &out_grad,
426 const std::vector<int> &in_data,
427 const std::vector<int> &out_data,
428 const std::vector<void*> &in_grad)
const {
429 return std::vector<std::pair<int, void*> >();
445 const std::vector<T> &in_data,
446 const std::vector<T> &out_data)
const {
448 std::vector<int> out_grad_index(out_grad.size());
449 std::vector<int> in_data_index(in_data.size());
450 std::vector<int> out_data_index(out_data.size());
451 for (
size_t i = 0; i < out_grad_index.size(); ++i) {
452 out_grad_index[i] = counter++;
454 for (
size_t i = 0; i < in_data_index.size(); ++i) {
455 in_data_index[i] = counter++;
457 for (
size_t i = 0; i < out_data_index.size(); ++i) {
458 out_data_index[i] = counter++;
460 std::vector<T> all_data;
461 all_data.insert(all_data.end(), out_grad.begin(), out_grad.end());
462 all_data.insert(all_data.end(), in_data.begin(), in_data.end());
463 all_data.insert(all_data.end(), out_data.begin(), out_data.end());
466 out_grad_index, in_data_index, out_data_index);
468 std::vector<T> ret(ret_index.size());
469 for (
size_t i = 0; i < ret_index.size(); ++i) {
470 ret[i] = all_data[ret_index[i]];
488 :
public dmlc::FunctionRegEntryBase<OperatorPropertyReg,
489 OperatorPropertyFactory> {
513 CHECK_EQ(this->name, type)
514 <<
"Register Name and TypeString mismatch, name=\"" << this->name <<
"\","
515 <<
" but TypeString=\"" << type <<
"\"";
537 #define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \
538 DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \
539 .set_body([]() { return new OperatorPropertyType(); }) \
540 .set_return_type("Symbol") \
543 #endif // DMLC_USE_CXX11
545 #endif // MXNET_OPERATOR_H_
virtual std::vector< int > DeclareBackwardDependency(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data) const
Declare the input requirement of Backward pass.
Definition: operator.h:361
OperatorPropertyReg & check_name()
Check if TypeString of the type matches the registered name.
Definition: operator.h:509
no operation, do not write anything
Definition: operator.h:25
write gradient to provided space
Definition: operator.h:27
Forward/Backward are synchronous calls.
Definition: operator.h:81
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:192
virtual void Forward(const OpContext &ctx, const std::vector< TBlob > &in_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &out_data, const std::vector< TBlob > &aux_states)=0
perform a forward operation of Operator, save the output to TBlob.
virtual void Backward(const OpContext &ctx, const std::vector< TBlob > &out_grad, const std::vector< TBlob > &in_data, const std::vector< TBlob > &out_data, const std::vector< OpReqType > &req, const std::vector< TBlob > &in_grad, const std::vector< TBlob > &aux_states)
Perform a Backward Operation, write gradient to the in_grad.
Definition: operator.h:141
virtual ~Operator()
destructor
Definition: operator.h:96
std::function< OperatorProperty *()> OperatorPropertyFactory
typedef the factory function of operator property
Definition: operator.h:483
std::vector< T > BackwardInputs(const std::vector< T > &out_grad, const std::vector< T > &in_data, const std::vector< T > &out_data) const
Get Backward Input Dependency for generic types of data. Normally T can be pointer of Symbol::DataEnt...
Definition: operator.h:444
virtual std::vector< ResourceRequest > BackwardResource(const std::vector< TShape > &in_shape) const
Declare additional resource required in backward pass. These additional resources will be presented i...
Definition: operator.h:335
static OperatorProperty * Create(const char *type_name)
create OperatorProperty
execution time context. The information needed in runtime for actual execution.
Definition: base.h:181
OperatorPropertyReg & set_key_var_num_args(const std::string &key)
Set key_var_num_args. When this is set, the API caller is required to pass in an argument with key=key_...
Definition: operator.h:502
virtual ~OperatorProperty()
virtual destructor
Definition: operator.h:170
engine::CallbackOnComplete async_on_complete
the callback when operation completes, used by asynchronous ops
Definition: operator.h:50
virtual std::vector< std::string > ListOutputs() const
Get name of output values of Operator.
Definition: operator.h:193
All the possible information needed by Operator.Forward and Backward. This is the superset of RunConte...
Definition: operator.h:44
int is_train
whether it is training phase
Definition: operator.h:46
Operator interface. Operator defines the basic operation unit of optimized computation graph in mxnet...
Definition: operator.h:76
virtual std::map< std::string, std::string > GetParams() const =0
Get a map representation of internal parameters. This can be used by Init to recover the state of Ope...
Global resource allocation handling.
Forward/Backward are asynchronous, will call OpContext.async_on_complete when operation finishes...
Definition: operator.h:86
virtual std::vector< std::pair< int, void * > > ForwardInplaceOption(const std::vector< int > &in_data, const std::vector< void * > &out_data) const
Get possible forward inplace options. This function enables optimization to reuse memory of inputs in...
Definition: operator.h:393
OperatorProperty is an object that stores all information about Operator. It also contains method to g...
Definition: operator.h:165
virtual Operator * CreateOperatorEx(Context ctx, std::vector< TShape > *in_shape, std::vector< int > *in_type) const
Create an Operator on a specific context and input shape/type.
Definition: operator.h:296
virtual std::vector< std::pair< int, void * > > BackwardInplaceOption(const std::vector< int > &out_grad, const std::vector< int > &in_data, const std::vector< int > &out_data, const std::vector< void * > &in_grad) const
Get possible backward inplace options. This function enables optimization to reuse memory of inputs i...
Definition: operator.h:424
virtual bool InferType(std::vector< int > *in_type, std::vector< int > *out_type, std::vector< int > *aux_type) const
infer the data types of outputs and unknown input arguments
Definition: operator.h:259
virtual std::vector< std::string > ListAuxiliaryStates() const
Get name of auxiliary states of Operator.
Definition: operator.h:200
virtual std::vector< std::string > ListArguments() const
Get input arguments of the Operator.
Definition: operator.h:186
perform an inplace write; Target shares memory with one of input arguments. This option only happens w...
Definition: operator.h:33
OpReqType
operation request type to Forward and Backward
Definition: operator.h:23
ExecType
the execution type of the operator
Definition: operator.h:79
Cross device copy operation; this is a special operator that indicates copy across devices...
Definition: operator.h:93
std::vector< Resource > requested
Resources requested by the operator.
Definition: operator.h:52
RunContext run_ctx
RunContext related resources.
Definition: operator.h:48
virtual void Init(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
Initialize the Operator by setting the parameters. This function needs to be called before all other fu...
std::string key_var_num_args
The key num_args name.
Definition: operator.h:520
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:36
configuration of mxnet as well as basic data structure.
virtual int NumVisibleOutputs() const
get number of visible return values during Symbol creation. If NumVisibleOutputs() = k...
Definition: operator.h:219
Registry entry for OperatorProperty factory functions.
Definition: operator.h:487
virtual ExecType exec_type() const
Definition: operator.h:151
add to the provided space
Definition: operator.h:35
virtual bool InferShape(std::vector< TShape > *in_shape, std::vector< TShape > *out_shape, std::vector< TShape > *aux_shape) const =0
infer the shapes of outputs and unknown input arguments
virtual std::vector< ResourceRequest > ForwardResource(const std::vector< TShape > &in_shape) const
Declare additional resource required in forward pass. These additional resources will be presented in...
Definition: operator.h:324
virtual Operator * CreateOperator(Context ctx) const =0
Create an Operator on a specific context.
virtual OperatorProperty * Copy() const =0
Copy this OperatorProperty.
virtual int NumOutputs() const
Definition: operator.h:204
Context information about the execution environment.
Definition: base.h:90
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: operator.h:59
virtual std::string TypeString() const =0
return the type string of the Operator; subclasses override this function.