7 #ifndef XGBOOST_COMMON_ROW_SET_H_ 8 #define XGBOOST_COMMON_ROW_SET_H_ 26 const size_t*
end{
nullptr};
36 inline size_t Size()
const {
46 inline std::vector<Elem>::const_iterator
begin()
const {
47 return elem_of_each_node_.begin();
50 inline std::vector<Elem>::const_iterator
end()
const {
51 return elem_of_each_node_.end();
57 CHECK(e.
begin !=
nullptr)
58 <<
"access element that is not in the set";
70 elem_of_each_node_.clear();
74 CHECK_EQ(elem_of_each_node_.size(), 0U);
76 if (row_indices_.empty()) {
82 const size_t*
begin =
reinterpret_cast<size_t*
>(20);
84 elem_of_each_node_.emplace_back(
Elem(begin, end, 0));
88 const size_t*
begin = dmlc::BeginPtr(row_indices_);
89 const size_t*
end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
90 elem_of_each_node_.emplace_back(
Elem(begin, end, 0));
93 std::vector<size_t>*
Data() {
return &row_indices_; }
96 unsigned left_node_id,
97 unsigned right_node_id,
101 CHECK(e.
begin !=
nullptr);
102 size_t* all_begin = dmlc::BeginPtr(row_indices_);
103 size_t*
begin = all_begin + (e.
begin - all_begin);
105 CHECK_EQ(n_left + n_right, e.
Size());
106 CHECK_LE(begin + n_left, e.
end);
107 CHECK_EQ(begin + n_left + n_right, e.
end);
109 if (left_node_id >= elem_of_each_node_.size()) {
110 elem_of_each_node_.resize(left_node_id + 1,
Elem(
nullptr,
nullptr, -1));
112 if (right_node_id >= elem_of_each_node_.size()) {
113 elem_of_each_node_.resize(right_node_id + 1,
Elem(
nullptr,
nullptr, -1));
116 elem_of_each_node_[left_node_id] =
Elem(begin, begin + n_left, left_node_id);
117 elem_of_each_node_[right_node_id] =
Elem(begin + n_left, e.
end, right_node_id);
118 elem_of_each_node_[
node_id] =
Elem(
nullptr,
nullptr, -1);
123 std::vector<size_t> row_indices_;
125 std::vector<Elem> elem_of_each_node_;
134 template<
size_t BlockSize>
137 template<
typename Func>
138 void Init(
const size_t n_tasks,
size_t n_nodes, Func funcNTaks) {
139 left_right_nodes_sizes_.resize(n_nodes);
140 blocks_offsets_.resize(n_nodes+1);
142 blocks_offsets_[0] = 0;
143 for (
size_t i = 1; i < n_nodes+1; ++i) {
144 blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTaks(i-1);
147 if (n_tasks > max_n_tasks_) {
148 mem_blocks_.resize(n_tasks);
149 max_n_tasks_ = n_tasks;
154 const size_t task_idx = GetTaskIdx(nid, begin);
155 return { mem_blocks_.at(task_idx).Left(), end - begin };
159 const size_t task_idx = GetTaskIdx(nid, begin);
160 return { mem_blocks_.at(task_idx).Right(), end - begin };
164 size_t task_idx = GetTaskIdx(nid, begin);
165 mem_blocks_.at(task_idx).n_left = n_left;
169 size_t task_idx = GetTaskIdx(nid, begin);
170 mem_blocks_.at(task_idx).n_right = n_right;
175 return left_right_nodes_sizes_[nid].first;
179 return left_right_nodes_sizes_[nid].second;
185 for (
size_t i = 0; i < blocks_offsets_.size()-1; ++i) {
187 for (
size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
188 mem_blocks_[j].n_offset_left = n_left;
189 n_left += mem_blocks_[j].n_left;
192 for (
size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
193 mem_blocks_[j].n_offset_right = n_left + n_right;
194 n_right += mem_blocks_[j].n_right;
196 left_right_nodes_sizes_[i] = {n_left, n_right};
201 size_t task_idx = GetTaskIdx(nid, begin);
203 size_t* left_result = rows_indexes + mem_blocks_[task_idx].n_offset_left;
204 size_t* right_result = rows_indexes + mem_blocks_[task_idx].n_offset_right;
206 const size_t* left = mem_blocks_[task_idx].Left();
207 const size_t* right = mem_blocks_[task_idx].Right();
209 std::copy_n(left, mem_blocks_[task_idx].n_left, left_result);
210 std::copy_n(right, mem_blocks_[task_idx].n_right, right_result);
215 return blocks_offsets_[nid] + begin / BlockSize;
226 return &left_data_[0];
230 return &right_data_[0];
233 alignas(128)
size_t left_data_[BlockSize];
234 alignas(128)
size_t right_data_[BlockSize];
239 size_t max_n_tasks_ = 0;
246 #endif // XGBOOST_COMMON_ROW_SET_H_ size_t n_offset_left
Definition: row_set.h:222
Definition: row_set.h:135
void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id, size_t n_left, size_t n_right)
Definition: row_set.h:95
Collection of row sets.
Definition: row_set.h:19
std::vector< Elem >::const_iterator begin() const
Definition: row_set.h:46
size_t n_left
Definition: row_set.h:219
The input data structure of xgboost.
std::vector< BlockInfo > mem_blocks_
Definition: row_set.h:238
void Init()
Definition: row_set.h:73
Span class implementation, based on the ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:126
Elem & operator[](unsigned node_id)
Return the element set corresponding to the given node_id.
Definition: row_set.h:63
std::vector< Elem >::const_iterator end() const
Definition: row_set.h:50
common::Span< size_t > GetRightBuffer(int nid, size_t begin, size_t end)
Definition: row_set.h:158
size_t * Left()
Definition: row_set.h:225
int node_id
Definition: row_set.h:27
size_t n_offset_right
Definition: row_set.h:223
size_t GetNLeftElems(int nid) const
Definition: row_set.h:174
void CalculateRowOffsets()
Definition: row_set.h:184
namespace of xgboost
Definition: base.h:102
Definition: row_set.h:218
Data structure to store an instance set: a subset of rows (instances) associated with a particular node.
Definition: row_set.h:24
void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right)
Definition: row_set.h:168
void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left)
Definition: row_set.h:163
size_t * Right()
Definition: row_set.h:229
const size_t * begin
Definition: row_set.h:25
std::vector< size_t > blocks_offsets_
Definition: row_set.h:237
const Elem & operator[](unsigned node_id) const
Return the element set corresponding to the given node_id.
Definition: row_set.h:55
const size_t * end
Definition: row_set.h:26
void MergeToArray(int nid, size_t begin, size_t *rows_indexes)
Definition: row_set.h:200
std::vector< std::pair< size_t, size_t > > left_right_nodes_sizes_
Definition: row_set.h:236
std::vector< size_t > left
Definition: row_set.h:42
size_t GetNRightElems(int nid) const
Definition: row_set.h:178
std::vector< size_t > right
Definition: row_set.h:43
void Init(const size_t n_tasks, size_t n_nodes, Func funcNTaks)
Definition: row_set.h:138
std::vector< size_t > * Data()
Definition: row_set.h:93
size_t GetTaskIdx(int nid, size_t begin)
Definition: row_set.h:214
void Clear()
Definition: row_set.h:69
Elem(const size_t *begin, const size_t *end, int node_id=-1)
Definition: row_set.h:31
common::Span< size_t > GetLeftBuffer(int nid, size_t begin, size_t end)
Definition: row_set.h:153
size_t Size() const
Definition: row_set.h:36
size_t n_right
Definition: row_set.h:220