8 #ifndef XGBOOST_LEARNER_H_ 9 #define XGBOOST_LEARNER_H_ 12 #include <rabit/rabit.h> 29 class GradientBooster;
73 virtual void Configure() = 0;
80 virtual void UpdateOneIter(
int iter, std::shared_ptr<DMatrix> train) = 0;
88 virtual void BoostOneIter(
int iter,
89 std::shared_ptr<DMatrix> train,
98 virtual std::string EvalOneIter(
int iter,
99 const std::vector<std::shared_ptr<DMatrix>>& data_sets,
100 const std::vector<std::string>& data_names) = 0;
114 virtual void Predict(std::shared_ptr<DMatrix> data,
117 unsigned ntree_limit = 0,
118 bool training =
false,
119 bool pred_leaf =
false,
120 bool pred_contribs =
false,
121 bool approx_contribs =
false,
122 bool pred_interactions =
false) = 0;
134 virtual void InplacePredict(dmlc::any
const& x, std::string
const& type,
137 uint32_t layer_begin = 0, uint32_t layer_end = 0) = 0;
139 void LoadModel(
Json const& in)
override = 0;
140 void SaveModel(
Json* out)
const override = 0;
142 virtual void LoadModel(dmlc::Stream* fi) = 0;
143 virtual void SaveModel(dmlc::Stream* fo)
const = 0;
150 virtual void SetParams(
Args const& args) = 0;
159 virtual void SetParam(
const std::string& key,
const std::string& value) = 0;
165 virtual uint32_t GetNumFeature() = 0;
175 virtual void SetAttr(
const std::string& key,
const std::string& value) = 0;
183 virtual bool GetAttr(
const std::string& key, std::string* out)
const = 0;
189 virtual bool DelAttr(
const std::string& key) = 0;
194 virtual std::vector<std::string> GetAttrNames()
const = 0;
198 bool AllowLazyCheckPoint()
const;
206 virtual std::vector<std::string> DumpModel(
const FeatureMap& fmap,
208 std::string format) = 0;
216 static Learner* Create(
const std::vector<std::shared_ptr<DMatrix> >& cache_data);
223 virtual const std::map<std::string, std::string>& GetConfigurationArguments()
const = 0;
227 std::unique_ptr<ObjFunction>
obj_;
229 std::unique_ptr<GradientBooster>
gbm_;
236 struct LearnerModelParamLegacy;
245 uint32_t num_feature { 0 };
247 uint32_t num_output_group { 0 };
258 #endif // XGBOOST_LEARNER_H_ Interface of predictor, performs predictions for a gradient booster.
float bst_float
float type, used for storing statistics
Definition: base.h:111
std::vector< std::unique_ptr< Metric > > metrics_
The evaluation metrics used to evaluate the model.
Definition: learner.h:231
Definition: learner.h:241
Definition: host_device_vector.h:86
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:253
Definition: generic_parameters.h:14
Defines the abstract interface for different components in XGBoost.
std::unique_ptr< GradientBooster > gbm_
The gradient booster used by the model.
Definition: learner.h:229
A device-and-host vector abstraction layer.
bool Initialized() const
Definition: learner.h:254
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
PredictionCacheEntry prediction_entry
Definition: learner.h:46
std::vector< bst_float > ret_vec_float
returning float vector.
Definition: learner.h:43
entry to to easily hold returning information
Definition: learner.h:35
std::vector< GradientPair > tmp_gpair
temp variable of gradient pairs.
Definition: learner.h:45
std::unique_ptr< ObjFunction > obj_
objective function
Definition: learner.h:227
Feature map data structure to help visualization and model dump.
namespace of xgboost
Definition: base.h:102
std::vector< const char * > ret_vec_charp
result holder for returning string pointers
Definition: learner.h:41
defines configuration macros of xgboost.
Learner class that does training and prediction. This is the user facing module of xgboost training...
Definition: learner.h:66
Data structure representing JSON format.
Definition: json.h:326
std::string ret_str
result holder for returning string
Definition: learner.h:37
std::vector< std::string > ret_vec_str
result holder for returning strings
Definition: learner.h:39
Contains pointer to input matrix and associated cached predictions.
Definition: predictor.h:35
GenericParameter generic_parameters_
Training parameter.
Definition: learner.h:233