In this article we will look inside the implementations of XGBoost and LightGBM. The reader will also get acquainted with the leaves library for Go, which makes it possible to compute predictions for tree ensembles without using the C API of the original libraries.

The final prediction of an ensemble is the sum of the predictions of its individual trees:

double pred = 0.0;
for (auto& tree: trees)
    pred += tree->Predict(feature_values);
To understand how a single tree produces its prediction, let's look inside the implementations. The code below is C++, since XGBoost and LightGBM are written in this language. I will omit irrelevant details and try to give the most concise code possible. We will see what happens inside Predict and how the tree data structure is organized.

In XGBoost there are several classes (in the OOP sense) of trees. We will talk about RegTree (see include/xgboost/tree_model.h), which, judging by the documentation, is the main one. If we keep only the details that are important for predictions, the members of the class look as simple as possible:

class RegTree {
  // vector of nodes
  std::vector<Node> nodes_;
};
The decision rule is implemented in the GetNext function. The code is slightly modified here, without affecting the result of the calculations:

// get next position of the tree given current pid
int RegTree::GetNext(int pid, float fvalue, bool is_unknown) const {
  const auto& node = nodes_[pid];
  float split_value = node.info_.split_cond;
  if (is_unknown) {
    return node.DefaultLeft() ? node.cleft_ : node.cright_;
  } else {
    if (fvalue < split_value) {
      return node.cleft_;
    } else {
      return node.cright_;
    }
  }
}
RegTree works only with real-valued features (type float). The tree nodes are represented by the Node class. It contains the local tree structure, the decision rule and the leaf value:

class Node {
 public:
  // feature index of split condition
  unsigned SplitIndex() const {
    return sindex_ & ((1U << 31) - 1U);
  }
  // when feature is unknown, whether goes to left child
  bool DefaultLeft() const {
    return (sindex_ >> 31) != 0;
  }
  // whether current node is leaf node
  bool IsLeaf() const {
    return cleft_ == -1;
  }

 private:
  // in leaf node, we have weights, in non-leaf nodes, we have split condition
  union Info {
    float leaf_value;
    float split_cond;
  } info_;
  // pointer to left, right
  int cleft_, cright_;
  // split feature index, left split or right split depends on the highest bit
  unsigned sindex_{0};
};
A leaf node is identified by the condition cleft_ == -1. The info_ field is declared as a union, i.e. two data types (in this case identical ones) share the same piece of memory depending on the node type. The highest bit of sindex_ determines which way an object with a missing feature value goes.

To trace the path from a call to the RegTree::Predict method to getting an answer, here are the two missing functions:

float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const {
  int pid = this->GetLeafIndex(feat, root_id);
  return nodes_[pid].leaf_value;
}

int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) const {
  auto pid = static_cast<int>(root_id);
  while (!nodes_[pid].IsLeaf()) {
    unsigned split_index = nodes_[pid].SplitIndex();
    pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
  }
  return pid;
}
In the GetLeafIndex function we go down the tree nodes in a loop until we reach a leaf.
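Since the leaves library discussed below has to reproduce this logic in Go, it may be useful to see roughly what such a traversal could look like there. The following is only a minimal sketch for illustration, not the actual leaves code; the node layout and the names node, splitIndex, defaultLeft and predict are assumptions that mirror the simplified RegTree structures above.

package main

import (
    "fmt"
    "math"
)

// node mirrors the simplified RegTree layout: the highest bit of sindex
// stores the default direction for missing values, the lower 31 bits store
// the split feature index, and left == -1 marks a leaf (like cleft_ == -1).
type node struct {
    sindex uint32
    value  float64 // split condition for inner nodes, leaf value for leaves
    left   int
    right  int
}

func splitIndex(sindex uint32) uint32 { return sindex & ((1 << 31) - 1) }
func defaultLeft(sindex uint32) bool  { return sindex>>31 != 0 }

// predict walks from the root down to a leaf, as GetLeafIndex does above.
func predict(nodes []node, fvals []float64) float64 {
    i := 0
    for nodes[i].left != -1 {
        fval := fvals[splitIndex(nodes[i].sindex)]
        switch {
        case math.IsNaN(fval) && defaultLeft(nodes[i].sindex):
            i = nodes[i].left
        case math.IsNaN(fval):
            i = nodes[i].right
        case fval < nodes[i].value:
            i = nodes[i].left
        default:
            i = nodes[i].right
        }
    }
    return nodes[i].value
}

func main() {
    // a single split on feature 0 at threshold 3.0, with two leaves
    tree := []node{
        {sindex: 0, value: 3.0, left: 1, right: 2},
        {left: -1, value: 1.5},
        {left: -1, value: -0.5},
    }
    fmt.Println(predict(tree, []float64{2.0})) // prints 1.5
}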
In LightGBM, the Tree class (see include/LightGBM/tree.h) contains arrays of values, where the node number is used as the index. The leaf values are also stored in separate arrays:

class Tree {
  // Number of current leaves
  int num_leaves_;
  // A non-leaf node's left child
  std::vector<int> left_child_;
  // A non-leaf node's right child
  std::vector<int> right_child_;
  // A non-leaf node's split feature, the original index
  std::vector<int> split_feature_;
  // A non-leaf node's split threshold in feature value
  std::vector<double> threshold_;
  std::vector<int> cat_boundaries_;
  std::vector<uint32_t> cat_threshold_;
  // Store the information for categorical feature handle and missing value handle.
  std::vector<int8_t> decision_type_;
  // Output of leaves
  std::vector<double> leaf_value_;
};
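To make the array-based layout more concrete, here is a minimal sketch in Go (illustrative types and names, not the actual LightGBM or leaves code) of how a numerical split could be evaluated when all per-node data lives in parallel slices indexed by node number. Missing values, categorical splits and leaf bookkeeping are deliberately left out, and the "go left when the value is <= threshold" rule is an assumption of this sketch.

package main

import "fmt"

// tree keeps per-node data in parallel slices, as the Tree class above does.
type tree struct {
    leftChild    []int
    rightChild   []int
    splitFeature []int
    threshold    []float64
}

// numericalDecision returns the child index chosen by a numerical split.
func (t *tree) numericalDecision(node int, fvals []float64) int {
    if fvals[t.splitFeature[node]] <= t.threshold[node] {
        return t.leftChild[node]
    }
    return t.rightChild[node]
}

func main() {
    t := &tree{
        leftChild:    []int{1},
        rightChild:   []int{2},
        splitFeature: []int{0},
        threshold:    []float64{3.0},
    }
    fmt.Println(t.numericalDecision(0, []float64{2.0})) // prints 1, i.e. go left
}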
LightGBM supports categorical features. Support is implemented using a bit field stored in cat_threshold_ for all nodes. cat_boundaries_ stores which part of the bit field corresponds to which node. For the categorical case, the threshold_ field is converted to an int and used as an index into cat_boundaries_ to find the beginning of the node's bit field.

int CategoricalDecision(double fval, int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  int int_fval = static_cast<int>(fval);
  if (int_fval < 0) {
    return right_child_[node];
  } else if (std::isnan(fval)) {
    // NaN is always in the right
    if (missing_type == 2) {
      return right_child_[node];
    }
    int_fval = 0;
  }
  int cat_idx = static_cast<int>(threshold_[node]);
  if (FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx],
                   int_fval)) {
    return left_child_[node];
  }
  return right_child_[node];
}
If missing_type == 2, a NaN value automatically sends the decision down the right branch of the tree. Otherwise, NaN is replaced with 0. Finding a value in the bit field is quite simple:

bool FindInBitset(const uint32_t* bits, int n, int pos) {
  int i1 = pos / 32;
  if (i1 >= n) {
    return false;
  }
  int i2 = pos % 32;
  return (bits[i1] >> i2) & 1;
}
For example, for int_fval = 42 the function checks whether bit 42 (numbering from 0) is set in the array.
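The same lookup is easy to express in Go. This is a hedged sketch rather than the actual leaves implementation, and it reproduces the bit-42 example above:

package main

import "fmt"

// findInBitset checks whether bit pos is set in a bitset stored as
// 32-bit words, just like FindInBitset above.
func findInBitset(bits []uint32, pos int) bool {
    i1 := pos / 32
    if i1 >= len(bits) {
        return false
    }
    return (bits[i1]>>uint(pos%32))&1 != 0
}

func main() {
    // categories 2 and 42 are allowed: 2 sits in word 0, 42 in word 1 (bit 10)
    bits := []uint32{1 << 2, 1 << 10}
    fmt.Println(findInBitset(bits, 42)) // true
    fmt.Println(findInBitset(bits, 7))  // false
}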
While working through the code I sent a couple of small fixes to LightGBM, and they were accepted. The handling of missing values for numerical features is similar to XGBoost, and I'll skip it for brevity.
XGBoost and LightGBM are very powerful libraries for building gradient boosting models on decision trees. To use them in a backend service where machine learning algorithms are needed, a number of integration tasks have to be solved.
Go is a popular language for writing backend services. Using XGBoost or LightGBM through the C API and cgo is not the easiest solution: the build becomes more complicated, careless handling can result in a SIGTERM, and there are problems with the number of system threads (OpenMP inside the libraries vs. the Go runtime).
So I wrote a library in Go for making predictions using models built in XGBoost or LightGBM. It is called leaves.
The library can read LightGBM models and XGBoost models and make predictions from them. Here is a minimal Go program that loads a model from disk and prints the prediction to the screen:

package main

import (
    "bufio"
    "fmt"
    "os"

    "github.com/dmitryikh/leaves"
)

func main() {
    // 1. Open the model file
    path := "lightgbm_model.txt"
    reader, err := os.Open(path)
    if err != nil {
        panic(err)
    }
    defer reader.Close()

    // 2. Read the LightGBM model
    model, err := leaves.LGEnsembleFromReader(bufio.NewReader(reader))
    if err != nil {
        panic(err)
    }

    // 3. Make a prediction!
    fvals := []float64{1.0, 2.0, 3.0}
    p := model.Predict(fvals, 0)
    fmt.Printf("Prediction for %v: %f\n", fvals, p)
}
To load an XGBoost model, simply call the leaves.XGEnsembleFromReader method instead of the one above. Predictions can be made in batches by calling the model.PredictDense or model.PredictCSR methods. More usage scenarios can be found in the tests for leaves.
Even though the Go language is generally slower than C++ (mainly due to the heavier runtime and run-time checks), a number of optimizations have resulted in prediction speeds comparable to calling the C API of the original libraries.
In this article we looked inside the implementations of XGBoost and LightGBM. As you can see, the basic constructs are quite simple, and I encourage readers to take advantage of open source and study the code when questions arise about how it works. With the leaves library you can quite easily use leading-edge machine learning solutions in your production environment, practically without losing speed compared to the original C++ implementations.

Source: https://habr.com/ru/post/423495/