7 #ifndef XGBOOST_TREE_MODEL_H_ 8 #define XGBOOST_TREE_MODEL_H_ 11 #include <dmlc/parameter.h> 15 #include <xgboost/logging.h> 35 struct TreeParam :
public dmlc::Parameter<TreeParam> {
56 static_assert(
sizeof(
TreeParam) == (31 + 6) *
sizeof(
int),
57 "TreeParam: 64 bit align");
60 deprecated_num_roots = 1;
66 DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
67 DMLC_DECLARE_FIELD(num_feature)
68 .describe(
"Number of features used in tree construction.");
69 DMLC_DECLARE_FIELD(num_deleted);
70 DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
71 .describe(
"Size of leaf vector, reserved for vector tree");
91 int leaf_child_cnt {0};
95 loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
110 static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
118 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
119 "Node: 64 bit align");
121 Node(int32_t cleft, int32_t cright, int32_t parent,
122 uint32_t split_ind,
float split_cond,
bool default_left) :
123 parent_{parent}, cleft_{cleft}, cright_{cright} {
124 this->SetParent(parent_);
125 this->SetSplit(split_ind, split_cond, default_left);
134 return this->cright_;
138 return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
142 return sindex_ & ((1U << 31) - 1U);
146 return (sindex_ >> 31) != 0;
150 return cleft_ == kInvalidNodeId;
154 return (this->info_).leaf_value;
158 return (this->info_).split_cond;
162 return parent_ & ((1U << 31) - 1);
166 return (parent_ & (1U << 31)) != 0;
170 return sindex_ == kDeletedNodeMarker;
195 bool default_left =
false) {
196 if (default_left) split_index |= (1U << 31);
197 this->sindex_ = split_index;
198 (this->info_).split_cond = split_cond;
207 (this->info_).leaf_value = value;
208 this->cleft_ = kInvalidNodeId;
209 this->cright_ = right;
213 this->sindex_ = kDeletedNodeMarker;
221 if (is_left_child) pidx |= (1U << 31);
222 this->parent_ = pidx;
225 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
226 cright_ == b.cright_ && sindex_ == b.sindex_ &&
227 info_.leaf_value == b.info_.leaf_value;
241 int32_t parent_{kInvalidNodeId};
243 int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
256 CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
257 CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
258 this->DeleteNode(nodes_[rid].LeftChild());
259 this->DeleteNode(nodes_[rid].RightChild());
260 nodes_[rid].SetLeaf(value);
268 if (nodes_[rid].IsLeaf())
return;
269 if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
270 CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
272 if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
273 CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
275 this->ChangeToLeaf(rid, value);
286 for (
int i = 0; i < param.
num_nodes; i ++) {
287 nodes_[i].SetLeaf(0.0f);
288 nodes_[i].SetParent(kInvalidNodeId);
/*! \brief get const reference to nodes; returns the nodes_ member without copying it */
301 const std::vector<Node>&
GetNodes()
const {
return nodes_; }
316 void Load(dmlc::Stream* fi);
321 void Save(dmlc::Stream* fo)
const;
323 void LoadModel(
Json const& in)
override;
324 void SaveModel(
Json* out)
const override;
327 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
328 deleted_nodes_ == b.deleted_nodes_ && param == b.
param;
335 template <
typename Func>
void WalkTree(Func func)
const {
336 std::stack<bst_node_t> nodes;
339 while (!nodes.empty()) {
340 auto nidx = nodes.top();
345 auto left =
self[nidx].LeftChild();
346 auto right =
self[nidx].RightChild();
361 bool Equal(
const RegTree& b)
const;
381 bool default_left,
bst_float base_weight,
383 bst_float loss_change,
float sum_hess,
float left_sum,
385 bst_node_t leaf_right_child = kInvalidNodeId) {
386 int pleft = this->AllocNode();
387 int pright = this->AllocNode();
388 auto &node = nodes_[nid];
389 CHECK(node.IsLeaf());
390 node.SetLeftChild(pleft);
391 node.SetRightChild(pright);
392 nodes_[node.LeftChild()].SetParent(nid,
true);
393 nodes_[node.RightChild()].SetParent(nid,
false);
394 node.SetSplit(split_index, split_value,
397 nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
398 nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
400 this->Stat(nid) = {loss_change, sum_hess, base_weight};
401 this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
402 this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
411 while (!nodes_[nid].IsRoot()) {
413 nid = nodes_[nid].Parent();
422 if (nodes_[nid].IsLeaf())
return 0;
423 return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
424 MaxDepth(nodes_[nid].RightChild())+1);
/*!
 * \brief initialize the feature vector for the given size
 *
 * The inline definition fills every entry of data_ with an Entry whose
 * flag is -1, which is the value IsMissing() tests for.
 * \param size requested size of the vector
 *        NOTE(review): presumably data_ is resized to this value before the
 *        fill; the resize statement is not visible here -- confirm.
 */
452 void Init(
size_t size);
/*!
 * \brief check whether i-th entry is missing
 * \param i index of the entry in data_ to test
 * \return true when the flag of data_[i] equals -1
 */
479 bool IsMissing(
size_t i)
const;
490 std::vector<Entry> data_;
/*!
 * \brief get the leaf index
 *
 * The inline definition loops while the current node is not a leaf,
 * reading that node's SplitIndex() on each step.
 * \param feat dense feature vector used to traverse the tree
 * \return NOTE(review): presumably the id of the leaf node reached; the
 *         return statement is not visible here -- confirm.
 */
497 int GetLeafIndex(
const FVec& feat)
const;
506 bst_float* out_contribs,
int condition = 0,
507 unsigned condition_feature = 0)
const;
523 unsigned unique_depth, PathElement* parent_unique_path,
525 int parent_feature_index,
int condition,
526 unsigned condition_feature,
bst_float condition_fraction)
const;
/*!
 * \brief get next position of the tree given current pid
 *
 * The inline definition reads SplitCond() of node pid and returns one of
 * DefaultChild(), LeftChild() (when fvalue < split value) or RightChild().
 * \param pid id of the current node
 * \param fvalue feature value compared against the node's split condition
 * \param is_unknown NOTE(review): presumably selects the DefaultChild()
 *        branch when the feature value is missing; the guarding condition
 *        is not visible here -- confirm.
 * \return id of the child node to visit next
 */
541 inline int GetNext(
int pid,
bst_float fvalue,
bool is_unknown)
const;
551 std::string format)
const;
555 void FillNodeMeanValues();
559 std::vector<Node> nodes_;
561 std::vector<int> deleted_nodes_;
563 std::vector<RTreeNodeStat> stats_;
564 std::vector<bst_float> node_mean_values_;
569 int nid = deleted_nodes_.back();
570 deleted_nodes_.pop_back();
576 CHECK_LT(param.
num_nodes, std::numeric_limits<int>::max())
577 <<
"number of nodes in the tree exceed 2^31";
583 void DeleteNode(
int nid) {
585 auto pid = (*this)[nid].Parent();
586 if (nid == (*
this)[pid].LeftChild()) {
587 (*this)[pid].SetLeftChild(kInvalidNodeId);
589 (*this)[pid].SetRightChild(kInvalidNodeId);
592 deleted_nodes_.push_back(nid);
593 nodes_[nid].MarkDelete();
600 Entry e; e.flag = -1;
602 std::fill(data_.begin(), data_.end(), e);
606 for (
auto const& entry : inst) {
607 if (entry.index >= data_.size()) {
610 data_[entry.index].fvalue = entry.fvalue;
615 for (
auto const& entry : inst) {
616 if (entry.index >= data_.size()) {
619 data_[entry.index].flag = -1;
628 return data_[i].fvalue;
632 return data_[i].flag == -1;
637 while (!(*
this)[nid].IsLeaf()) {
638 unsigned split_index = (*this)[nid].SplitIndex();
646 bst_float split_value = (*this)[pid].SplitCond();
648 return (*
this)[pid].DefaultChild();
650 if (fvalue < split_value) {
651 return (*
this)[pid].LeftChild();
653 return (*
this)[pid].RightChild();
658 #endif // XGBOOST_TREE_MODEL_H_ int NumExtraNodes() const
number of extra nodes besides the root
Definition: tree_model.h:435
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:94
int deprecated_num_roots
(Deprecated) number of start roots
Definition: tree_model.h:37
int GetNext(int pid, bst_float fvalue, bool is_unknown) const
get next position of the tree given current pid
Definition: tree_model.h:645
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:121
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:133
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Definition: tree_model.h:380
float bst_float
float type, used for storing statistics
Definition: base.h:111
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:43
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:206
XGBOOST_DEVICE bst_float LeafValue() const
Definition: tree_model.h:153
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:220
The input data structure of xgboost.
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:605
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:296
Defines the abstract interface for different components in XGBoost.
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:627
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:185
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:141
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition: tree_model.h:91
int GetDepth(int nid) const
get current depth
Definition: tree_model.h:409
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
XGBOOST_DEVICE bool DefaultLeft() const
when the feature is unknown, whether to go to the left child
Definition: tree_model.h:145
bst_float base_weight
weight of current node
Definition: tree_model.h:89
define regression tree to be the most common tree model. This is the data structure used in xgboost's...
Definition: tree_model.h:106
TreeParam()
constructor
Definition: tree_model.h:54
int size_leaf_vector
leaf vector size, used for vector tree used to store more than one dimensional information in tree ...
Definition: tree_model.h:50
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:126
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:96
node statistics used in regression tree
Definition: tree_model.h:83
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:87
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:255
XGBOOST_DEVICE Node()
Definition: tree_model.h:116
meta parameters of the tree
Definition: tree_model.h:35
bool operator==(const RegTree &b) const
Definition: tree_model.h:326
int num_deleted
number of deleted nodes
Definition: tree_model.h:41
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:421
int num_nodes
total number of nodes
Definition: tree_model.h:39
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:178
int32_t bst_node_t
Type for tree node index.
Definition: base.h:123
TreeParam param
model parameter
Definition: tree_model.h:279
bool operator==(const Node &b) const
Definition: tree_model.h:224
int GetLeafIndex(const FVec &feat) const
get the leaf index
Definition: tree_model.h:635
bool operator==(const TreeParam &b) const
Definition: tree_model.h:74
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:308
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:599
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
Feature map data structure to help visualization and model dump.
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:173
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:301
namespace of xgboost
Definition: base.h:102
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:109
tree node
Definition: tree_model.h:114
int MaxDepth()
get maximum depth
Definition: tree_model.h:430
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:194
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:137
defines configuration macros of xgboost.
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:212
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:165
void WalkTree(Func func) const
Definition: tree_model.h:335
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:623
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:149
RegTree()
constructor
Definition: tree_model.h:281
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:169
int num_feature
number of features used for tree construction
Definition: tree_model.h:45
void Drop(const SparsePage::Inst &inst)
drop the trace after fill; must be called after Fill.
Definition: tree_model.h:614
Data structure representing JSON format.
Definition: json.h:326
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:52
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:267
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:292
bst_float SplitCondT
Definition: tree_model.h:108
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:129
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:304
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:216
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:85
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:161
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:157
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:631
dense feature vector that can be taken by RegTree and can be constructed from sparse feature vector...
Definition: tree_model.h:447
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:63