7 #ifndef MXNET_SYMBOLIC_H_
8 #define MXNET_SYMBOLIC_H_
10 #include <dmlc/base.h>
11 #include <dmlc/json.h>
23 #if DMLC_USE_CXX11 == 0
24 #error "CXX11 was required for symbolic module"
51 void Print(std::ostream &os)
const;
77 void Compose(
const std::vector<Symbol>& args,
78 const std::string& name);
88 void Compose(
const std::unordered_map<std::string, Symbol>& kwargs,
89 const std::string& name);
104 void SetAttr(
const std::string &key,
const std::string& value);
113 bool GetAttr(
const std::string& key, std::string* out);
120 std::map<std::string, std::string>
ListAttr();
134 Symbol operator () (
const std::vector<Symbol>& args,
const std::string& name)
const;
142 const std::string& name)
const;
154 Symbol Grad(
const std::vector<std::string>& wrt)
const;
171 bool InferShape(std::vector<TShape> *arg_shapes,
172 std::vector<TShape> *out_shapes,
173 std::vector<TShape> *aux_shapes,
174 bool partial_infer =
false)
const;
186 bool InferShape(
const std::unordered_map<std::string, TShape> &known_arg_shapes,
187 std::vector<TShape> *arg_shapes,
188 std::vector<TShape> *out_shapes,
189 std::vector<TShape> *aux_shapes,
190 bool partial_infer =
false)
const;
207 bool InferType(std::vector<int> *arg_types,
208 std::vector<int> *out_types,
209 std::vector<int> *aux_types)
const;
219 bool InferType(
const std::unordered_map<std::string, int> &known_arg_types,
220 std::vector<int> *arg_types,
221 std::vector<int> *out_types,
222 std::vector<int> *aux_types)
const;
227 void Save(dmlc::JSONWriter *writer)
const;
232 void Load(dmlc::JSONReader *reader);
275 : source(source), index(index) {}
285 inline bool is_atomic()
const;
295 template<
typename FVisit>
296 inline void DFSVisit(FVisit fvisit)
const;
302 int FindDuplicateArgs(std::unordered_map<std::string, int> *out)
const;
331 virtual void Forward(
bool is_train) = 0;
340 virtual void PartialForward(
bool is_train,
int step,
int *step_left) = 0;
350 virtual void Backward(
const std::vector<NDArray> &head_grads) = 0;
/*! \brief print the execution plan info to the output stream os; this default implementation writes nothing, and subclasses may override it. */
355 virtual void Print(std::ostream &os)
const {}
360 virtual const std::vector<NDArray> &
outputs()
const = 0;
377 const std::map<std::string, Context>& group2ctx,
378 const std::vector<NDArray> &in_args,
379 const std::vector<NDArray> &arg_grad_store,
380 const std::vector<OpReqType> &grad_req_type,
381 const std::vector<NDArray> &aux_states,
393 #endif // MXNET_SYMBOLIC_H_
Executor of a computation graph. An Executor can be created by binding a symbol.
Definition: symbolic.h:323
uint32_t index
index of output from the source.
Definition: symbolic.h:270
virtual ~Executor()
destructor
Definition: symbolic.h:326
std::vector< DataEntry > heads_
the head nodes of Symbols. This head is only effective when
Definition: symbolic.h:281
Symbol operator()(const std::vector< Symbol > &args, const std::string &name) const
Apply the symbol as a function, compose with arguments.
static Symbol CreateGroup(const std::vector< Symbol > &symbols)
create an equivalent symbol by grouping the symbols together
static Executor * Bind(Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< NDArray > &in_args, const std::vector< NDArray > &arg_grad_store, const std::vector< OpReqType > &grad_req_type, const std::vector< NDArray > &aux_states, Executor *shared_exec=NULL)
Create an executor by binding the symbol with context and arguments. If the user does not want to compute the grad...
std::vector< std::string > ListAuxiliaryStates() const
bool InferShape(std::vector< TShape > *arg_shapes, std::vector< TShape > *out_shapes, std::vector< TShape > *aux_shapes, bool partial_infer=false) const
infer the shapes of outputs and unknown input arguments
std::function< void(const char *, void *)> MonitorCallback
the prototype of user-defined monitor callback
Definition: symbolic.h:386
Symbol Copy() const
copy the symbol
static Symbol CreateVariable(const std::string &name)
create variable symbol node
void Save(dmlc::JSONWriter *writer) const
interface for json serialization.
std::map< std::string, std::string > ListAttrShallow()
Get the attribute dictionary from the symbol. This only works for a symbol with outputs from a single operato...
Symbol Grad(const std::vector< std::string > &wrt) const
get the gradient graph
virtual void Print(std::ostream &os) const
print the execution plan info to output stream.
Definition: symbolic.h:355
Symbol operator[](size_t index) const
get the index th element from the returned tuple.
virtual const std::vector< NDArray > & outputs() const =0
get array of outputs in the executor.
friend class StaticGraph
let static graph know the contents
Definition: symbolic.h:316
static Symbol Create(OperatorProperty *op)
create a Symbol by wrapping an OperatorProperty. This function takes ownership of op ...
virtual void SetMonitorCallback(const MonitorCallback &callback)
Install a callback to notify the completion of operation.
Definition: symbolic.h:390
std::vector< std::string > ListArguments() const
List the arguments names.
NDArray interface that handles array arithmetic.
virtual void Forward(bool is_train)=0
Perform a Forward operation of the Operator. After this operation, the user can get the result by using functi...
Operator interface of mxnet.
an entry that represents output data from a node
Definition: symbolic.h:266
void SetAttr(const std::string &key, const std::string &value)
set additional attributes of the symbol. This only works for a symbol with outputs from a single operator...
void Print(std::ostream &os) const
print the symbol info to output stream.
virtual void PartialForward(bool is_train, int step, int *step_left)=0
Perform a Partial Forward operation of the Operator. Only issue the operations specified by step...
void Load(dmlc::JSONReader *reader)
interface for json serialization.
Symbol GetInternals() const
void Compose(const std::vector< Symbol > &args, const std::string &name)
Compose the symbol with arguments, this changes current symbol.
Symbol is used to represent dynamically generated symbolic computation graph.
Definition: symbolic.h:40
std::map< std::string, std::string > ListAttr()
Get the attribute dictionary from the symbol and all its children. Each attribute name is prepended with the...
std::shared_ptr< Node > source
the source node of this data
Definition: symbolic.h:268
OperatorProperty is an object that stores all information about an Operator. It also contains methods to g...
Definition: operator.h:165
bool InferType(std::vector< int > *arg_types, std::vector< int > *out_types, std::vector< int > *aux_types) const
infer the types of outputs and unknown input arguments
virtual void Backward(const std::vector< NDArray > &head_grads)=0
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
configuration of mxnet as well as basic data structures.
Context information about the execution environment.
Definition: base.h:90
DataEntry()
default constructor
Definition: symbolic.h:272
bool GetName(std::string *out)
Get the name from the symbol. This only works for a symbol with outputs from a single operator. For a grouped symbol, an error will be raised.
DataEntry(std::shared_ptr< Node > source, uint32_t index)
constructor from index
Definition: symbolic.h:274
size_t NumOutputs() const
get number of outputs of this symbol
Definition: symbolic.h:237
bool GetAttr(const std::string &key, std::string *out)
Get attributes from the symbol. This only works for a symbol with outputs from a single operator...
std::vector< std::string > ListOutputs() const