mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
symbolic.h
Go to the documentation of this file.
1 
7 #ifndef MXNET_SYMBOLIC_H_
8 #define MXNET_SYMBOLIC_H_
9 
#include <dmlc/base.h>
#include <dmlc/json.h>
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./base.h"
#include "./c_api.h"
#include "./ndarray.h"
#include "./operator.h"
21 
22 // check c++11
23 #if DMLC_USE_CXX11 == 0
24 #error "CXX11 was required for symbolic module"
25 #endif
26 
27 namespace mxnet {
32 class StaticGraph;
40 class Symbol {
41  public:
46  Symbol Copy() const;
51  void Print(std::ostream &os) const; // NOLINT(*)
58  std::vector<std::string> ListArguments() const;
60  std::vector<std::string> ListOutputs() const;
62  std::vector<std::string> ListAuxiliaryStates() const;
68  Symbol operator[] (size_t index) const;
77  void Compose(const std::vector<Symbol>& args,
78  const std::string& name);
88  void Compose(const std::unordered_map<std::string, Symbol>& kwargs,
89  const std::string& name);
96  bool GetName(std::string* out);
104  void SetAttr(const std::string &key, const std::string& value);
113  bool GetAttr(const std::string& key, std::string* out);
120  std::map<std::string, std::string> ListAttr();
127  std::map<std::string, std::string> ListAttrShallow();
134  Symbol operator () (const std::vector<Symbol>& args, const std::string& name) const;
141  Symbol operator () (const std::unordered_map<std::string, Symbol>& kwargs,
142  const std::string& name) const;
143  /*
144  * \brief Get all the internal nodes of the symbol.
145  * \return symbol A new symbol whose output contains all the outputs of the symbols
146  * Including input variables and intermediate outputs.
147  */
148  Symbol GetInternals() const;
154  Symbol Grad(const std::vector<std::string>& wrt) const;
171  bool InferShape(std::vector<TShape> *arg_shapes,
172  std::vector<TShape> *out_shapes,
173  std::vector<TShape> *aux_shapes,
174  bool partial_infer = false) const;
175 
186  bool InferShape(const std::unordered_map<std::string, TShape> &known_arg_shapes,
187  std::vector<TShape> *arg_shapes,
188  std::vector<TShape> *out_shapes,
189  std::vector<TShape> *aux_shapes,
190  bool partial_infer = false) const;
191 
207  bool InferType(std::vector<int> *arg_types,
208  std::vector<int> *out_types,
209  std::vector<int> *aux_types) const;
219  bool InferType(const std::unordered_map<std::string, int> &known_arg_types,
220  std::vector<int> *arg_types,
221  std::vector<int> *out_types,
222  std::vector<int> *aux_types) const;
227  void Save(dmlc::JSONWriter *writer) const;
232  void Load(dmlc::JSONReader *reader);
237  inline size_t NumOutputs() const {
238  return heads_.size();
239  }
248  static Symbol Create(OperatorProperty *op);
254  static Symbol CreateGroup(const std::vector<Symbol> &symbols);
260  static Symbol CreateVariable(const std::string &name);
261 
262  protected:
263  // Declare node, internal data structure.
264  struct Node;
266  struct DataEntry {
268  std::shared_ptr<Node> source;
270  uint32_t index;
274  DataEntry(std::shared_ptr<Node> source, uint32_t index)
275  : source(source), index(index) {}
276  };
281  std::vector<DataEntry> heads_;
282 
283  private:
285  inline bool is_atomic() const;
295  template<typename FVisit>
296  inline void DFSVisit(FVisit fvisit) const;
302  int FindDuplicateArgs(std::unordered_map<std::string, int> *out) const;
308  void ToStaticGraph(StaticGraph *out_graph) const;
314  void FromStaticGraph(const StaticGraph &graph);
316  friend class StaticGraph;
317 };
318 
323 class Executor {
324  public:
326  virtual ~Executor() {}
331  virtual void Forward(bool is_train) = 0;
340  virtual void PartialForward(bool is_train, int step, int *step_left) = 0;
350  virtual void Backward(const std::vector<NDArray> &head_grads) = 0;
355  virtual void Print(std::ostream &os) const {} // NOLINT(*)
360  virtual const std::vector<NDArray> &outputs() const = 0;
375  static Executor *Bind(Symbol symbol,
376  const Context& default_ctx,
377  const std::map<std::string, Context>& group2ctx,
378  const std::vector<NDArray> &in_args,
379  const std::vector<NDArray> &arg_grad_store,
380  const std::vector<OpReqType> &grad_req_type,
381  const std::vector<NDArray> &aux_states,
382  Executor* shared_exec = NULL);
386  typedef std::function<void(const char*, void*)> MonitorCallback;
390  virtual void SetMonitorCallback(const MonitorCallback& callback) {}
391 }; // class operator
392 } // namespace mxnet
393 #endif // MXNET_SYMBOLIC_H_
Executor of a computation graph. Executor can be created by Binding a symbol.
Definition: symbolic.h:323
uint32_t index
index of output from the source.
Definition: symbolic.h:270
C API of mxnet.
virtual ~Executor()
destructor
Definition: symbolic.h:326
std::vector< DataEntry > heads_
The head nodes of the symbol. This head is only effective when...
Definition: symbolic.h:281
Symbol operator()(const std::vector< Symbol > &args, const std::string &name) const
Apply the symbol as a function, compose with arguments.
static Symbol CreateGroup(const std::vector< Symbol > &symbols)
Create a symbol equivalent to grouping the given symbols together.
static Executor * Bind(Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< NDArray > &in_args, const std::vector< NDArray > &arg_grad_store, const std::vector< OpReqType > &grad_req_type, const std::vector< NDArray > &aux_states, Executor *shared_exec=NULL)
Create an executor by binding a symbol with a context and arguments. If the user does not want to compute the grad...
std::vector< std::string > ListAuxiliaryStates() const
bool InferShape(std::vector< TShape > *arg_shapes, std::vector< TShape > *out_shapes, std::vector< TShape > *aux_shapes, bool partial_infer=false) const
infer the shapes of outputs and unknown input arguments
std::function< void(const char *, void *)> MonitorCallback
the prototype of user-defined monitor callback
Definition: symbolic.h:386
Symbol Copy() const
copy the symbol
static Symbol CreateVariable(const std::string &name)
create variable symbol node
void Save(dmlc::JSONWriter *writer) const
interface for json serialization.
std::map< std::string, std::string > ListAttrShallow()
Get the attribute dictionary from the symbol. This only works for symbols with outputs from a single operato...
Symbol Grad(const std::vector< std::string > &wrt) const
get the gradient graph
virtual void Print(std::ostream &os) const
print the execution plan info to output stream.
Definition: symbolic.h:355
Symbol operator[](size_t index) const
Get the index-th element from the returned tuple.
virtual const std::vector< NDArray > & outputs() const =0
get array of outputs in the executor.
friend class StaticGraph
let static graph know the contents
Definition: symbolic.h:316
static Symbol Create(OperatorProperty *op)
Create a Symbol by wrapping an OperatorProperty. This function takes ownership of op ...
virtual void SetMonitorCallback(const MonitorCallback &callback)
Install a callback to notify the completion of operation.
Definition: symbolic.h:390
std::vector< std::string > ListArguments() const
List the arguments names.
NDArray interface that handles array arithmetic.
virtual void Forward(bool is_train)=0
Perform a Forward operation of the Operator. After this operation, the user can get the result by using functi...
Operator interface of mxnet.
an entry that represents output data from a node
Definition: symbolic.h:266
void SetAttr(const std::string &key, const std::string &value)
Set additional attributes of the symbol. This only works for symbols with outputs from a single operator...
void Print(std::ostream &os) const
print the symbol info to output stream.
virtual void PartialForward(bool is_train, int step, int *step_left)=0
Perform a Partial Forward operation of the Operator. Only issues the operations specified by step...
void Load(dmlc::JSONReader *reader)
interface for json serialization.
Symbol GetInternals() const
void Compose(const std::vector< Symbol > &args, const std::string &name)
Compose the symbol with arguments, this changes current symbol.
Symbol is used to represent dynamically generated symbolic computation graph.
Definition: symbolic.h:40
std::map< std::string, std::string > ListAttr()
Get attribute dictionary from the symbol and all children. Each attribute name is pre-pended with the...
std::shared_ptr< Node > source
the source node of this data
Definition: symbolic.h:268
OperatorProperty is an object that stores all information about an Operator. It also contains methods to g...
Definition: operator.h:165
bool InferType(std::vector< int > *arg_types, std::vector< int > *out_types, std::vector< int > *aux_types) const
infer the types of outputs and unknown input arguments
virtual void Backward(const std::vector< NDArray > &head_grads)=0
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
Configuration of mxnet as well as basic data structures.
Context information about the execution environment.
Definition: base.h:90
DataEntry()
enabled default constructor
Definition: symbolic.h:272
bool GetName(std::string *out)
Get the name of the symbol. This only works for symbols with outputs from a single operator. For a grouped symbol, an error will be raised.
DataEntry(std::shared_ptr< Node > source, uint32_t index)
constructor from index
Definition: symbolic.h:274
size_t NumOutputs() const
get number of outputs of this symbol
Definition: symbolic.h:237
bool GetAttr(const std::string &key, std::string *out)
Get attributes from the symbol. This only works for symbols with outputs from a single operator...
std::vector< std::string > ListOutputs() const