7 #ifndef MXNET_OPTIMIZER_H_
8 #define MXNET_OPTIMIZER_H_
10 #include <dmlc/base.h>
11 #include <dmlc/logging.h>
12 #include <dmlc/registry.h>
13 #include <mshadow/tensor.h>
41 virtual void Init(
const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
59 const NDArray *grad,
const float lr,
const float wd) = 0;
76 :
public dmlc::FunctionRegEntryBase<OptimizerReg,
// MXNET_REGISTER_OPTIMIZER(name, OptimizerType): registers an optimizer
// factory under `name` in the ::mxnet::OptimizerReg registry via
// DMLC_REGISTRY_REGISTER. The registered body is a captureless lambda that
// heap-allocates a value-initialized OptimizerType with `new` and returns it.
// NOTE(review): presumably the lambda is stored as the entry's
// OptimizerFactory (std::function<Optimizer *()>), and the caller of the
// factory (e.g. Optimizer::Create) owns and must delete the returned
// pointer. Confirm this against optimizer.cc.
93 #define MXNET_REGISTER_OPTIMIZER(name, OptimizerType) \
94 DMLC_REGISTRY_REGISTER(::mxnet::OptimizerReg, OptimizerReg, name) \
95 .set_body([]() { return new OptimizerType(); })
97 #endif // DMLC_USE_CXX11
100 #endif // MXNET_OPTIMIZER_H_
static Optimizer * Create(const char *type_name)
Create an Optimizer by its registered type name.
Definition: optimizer.h:30
virtual ~Optimizer()
virtual destructor
Definition: optimizer.h:35
virtual void Update(const int index, NDArray *weight, const NDArray *grad, const float lr, const float wd)=0
Update a weight with gradient.
virtual void CreateState(const int index, const NDArray *weight)=0
Create the auxiliary state for the weight with the given index.
NDArray interface that handles array arithmetic.
std::function< Optimizer *()> OptimizerFactory
Typedef for the Optimizer factory function.
Definition: optimizer.h:71
Global resource allocation handling.
Registry entry for Optimizer factory functions.
Definition: optimizer.h:75
configuration of mxnet as well as basic data structures.
virtual void Init(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
Initialize the Optimizer by setting the parameters. This function needs to be called before all other f...
ndarray interface
Definition: ndarray.h:31