mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
optimizer.h
Go to the documentation of this file.
1 
7 #ifndef MXNET_OPTIMIZER_H_
8 #define MXNET_OPTIMIZER_H_
9 
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <mshadow/tensor.h>
#include <functional>
#include <string>
#include <vector>
#include <utility>
#include "./base.h"
#include "./resource.h"

#if DMLC_USE_CXX11
#include <mxnet/ndarray.h>
#endif
23 
24 namespace mxnet {
25 
#if !DMLC_USE_CXX11
// Without C++11 the full ndarray.h is not included, so only a forward
// declaration is available; the interface below uses NDArray by pointer only.
class NDArray;
#endif

/*!
 * \brief Abstract optimizer interface.
 *
 * Concrete optimizers derive from this class and implement Init, CreateState
 * and Update. Create() constructs an optimizer from a type name.
 */
class Optimizer {
 public:
  /*!
   * \brief Virtual destructor, so a derived optimizer can be deleted through
   *  an Optimizer pointer.
   */
  // Kept as `{}` instead of `= default` because this header must also
  // compile when DMLC_USE_CXX11 is off.
  virtual ~Optimizer() {}
  /*!
   * \brief Initialize the Optimizer by setting its parameters.
   *  This function needs to be called before all other functions.
   * \param kwargs parameters as a list of (key, value) string pairs.
   */
  // `> >` (with a space) keeps the declaration valid for pre-C++11 compilers.
  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
  /*!
   * \brief Create the auxiliary state for the weight with the given index.
   * \param index index identifying the weight.
   * \param weight the weight the state is created for (read-only here).
   */
  virtual void CreateState(int index, const NDArray *weight) = 0;
  /*!
   * \brief Update a weight with its gradient.
   * \param index index identifying the weight.
   *  NOTE(review): presumably the same index previously passed to
   *  CreateState — confirm.
   * \param weight the weight to be updated (passed as non-const).
   * \param grad the gradient for \p weight (read-only).
   * \param lr presumably the learning rate — confirm.
   * \param wd presumably the weight-decay coefficient — confirm.
   */
  virtual void Update(int index, NDArray *weight,
                      const NDArray *grad, float lr, float wd) = 0;
  /*!
   * \brief Create an Optimizer.
   * \param type_name name of the optimizer type to create.
   * \return pointer to the created Optimizer.
   *  NOTE(review): registered factories return `new OptimizerType()`, so the
   *  result is presumably heap-allocated and owned by the caller — confirm.
   */
  static Optimizer *Create(const char* type_name);
};
67 
#if DMLC_USE_CXX11

/*! \brief Typedef for the factory function that creates an Optimizer. */
typedef std::function<Optimizer *()> OptimizerFactory;

/*!
 * \brief Registry entry for Optimizer factory functions.
 *
 * Each entry's body is an OptimizerFactory. OptimizerReg passes itself as the
 * first template argument of dmlc::FunctionRegEntryBase.
 */
struct OptimizerReg
    : public dmlc::FunctionRegEntryBase<OptimizerReg,
                                        OptimizerFactory> {
};

//--------------------------------------------------------------
// API registration of Optimizers
//--------------------------------------------------------------
/*!
 * \brief Register an Optimizer type in the OptimizerReg registry.
 *
 * The registered body allocates a fresh `OptimizerType` with `new` on every
 * call. OptimizerType must therefore be default-constructible, and an
 * `OptimizerType*` must convert to `Optimizer*` (i.e. it derives from
 * Optimizer) to fit OptimizerFactory.
 *
 * \param name registration name, forwarded to DMLC_REGISTRY_REGISTER.
 * \param OptimizerType concrete class implementing Optimizer.
 */
#define MXNET_REGISTER_OPTIMIZER(name, OptimizerType) \
  DMLC_REGISTRY_REGISTER(::mxnet::OptimizerReg, OptimizerReg, name) \
  .set_body([]() { return new OptimizerType(); })

#endif  // DMLC_USE_CXX11
98 
99 } // namespace mxnet
100 #endif // MXNET_OPTIMIZER_H_
static Optimizer * Create(const char *type_name)
create Optimizer
Definition: optimizer.h:30
virtual ~Optimizer()
virtual destructor
Definition: optimizer.h:35
virtual void Update(const int index, NDArray *weight, const NDArray *grad, const float lr, const float wd)=0
Update a weight with gradient.
virtual void CreateState(const int index, const NDArray *weight)=0
Create the auxiliary state for the weight with the given index.
NDArray interface that handles array arithmetic.
std::function< Optimizer *()> OptimizerFactory
typedef the factory function of Optimizer
Definition: optimizer.h:71
Global resource allocation handling.
Registry entry for Optimizer factory functions.
Definition: optimizer.h:75
Configuration of MXNet as well as basic data structures.
virtual void Init(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
Initialize the Optimizer by setting the parameters. This function needs to be called before all other functions.
ndarray interface
Definition: ndarray.h:31