Inspecting trees¶
Setup¶
pip install ydf -U
import ydf
import numpy as np
What does it mean to inspect trees?¶
A decision forest model, such as Random Forest or Gradient Boosted Decision Trees, is a collection of decision trees. A decision tree is made of "internal nodes" (i.e., nodes with child nodes) and "leaf nodes". Using the get_tree and print_tree methods, you can inspect the structure of the trees, their conditions, and their leaf values.
In this notebook, we train a simple CART model on a synthetic dataset and inspect its tree structure.
Synthetic dataset¶
Our dataset is composed of two input features and six examples.
dataset = {
"x1": np.array([0, 0, 0, 1, 1, 1]),
"x2": np.array([1, 1, 0, 0, 1, 1]),
"y": np.array([0, 0, 0, 0, 1, 1]),
}
dataset
{'x1': array([0, 0, 0, 1, 1, 1]), 'x2': array([1, 1, 0, 0, 1, 1]), 'y': array([0, 0, 0, 0, 1, 1])}
Training a model¶
model = ydf.CartLearner(label="y", min_examples=1, task=ydf.Task.REGRESSION).train(dataset)
model.describe()
Train model on 6 examples
Model trained in 0:00:00.000728
Task : REGRESSION
Label : y
Features (2) : x1 x2
Weights : None
Trained with tuner : No
Model size : 3 kB
Number of records: 6
Number of columns: 3

Number of columns by type:
    NUMERICAL: 3 (100%)

Columns:

NUMERICAL: 3 (100%)
    0: "y" NUMERICAL mean:0.333333 min:0 max:1 sd:0.471405
    1: "x1" NUMERICAL mean:0.5 min:0 max:1 sd:0.5
    2: "x2" NUMERICAL mean:0.666667 min:0 max:1 sd:0.471405

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
The Random Forest does not have out-of-bag evaluation training logs. Train the model with compute_oob_performances=True to compute the training logs. Make sure the training logs have not been removed with pure_serving_model=True.
Variable importances measure the importance of an input feature for a model.
1. "x1" 1.000000 ################ 2. "x2" 0.500000
1. "x1" 1.000000
1. "x1" 1.000000 2. "x2" 1.000000
1. "x1" 0.666667 2. "x2" 0.666667
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Tree #0:
    "x1">=0.5 [s:0.111111 n:6 np:3 miss:1] ; pred:0.333333
        ├─(pos)─ "x2">=0.5 [s:0.222222 n:3 np:2 miss:1] ; pred:0.666667
        |        ├─(pos)─ pred:1
        |        └─(neg)─ pred:0
        └─(neg)─ pred:0
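The variable importances listed above can also be retrieved programmatically. Below is a minimal sketch using the model's variable_importances method; the exact structure of the returned dictionary is an assumption here and may differ between YDF versions.

# Retrieve the variable importances computed during training.
# Assumed to return a dictionary keyed by importance name, with one
# list of (value, feature) entries per importance.
for name, entries in model.variable_importances().items():
    print(name, entries)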
Plotting the model¶
The tree of the model is visible in the "structure" tab of model.describe(). You can also print trees with the print_tree method.
model.print_tree()
'x1' >= 0.5 [score=0.11111 missing=True]
    ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True]
    │        ├─(pos)─ value=1 sd=0
    │        └─(neg)─ value=0 sd=0
    └─(neg)─ value=0 sd=0
Accessing the tree structure¶
The get_tree and get_all_trees methods give programmatic access to the structure of the trees.
Note: A CART model only has one tree, so the tree_idx argument is set to 0. For models with multiple trees, the number of trees is available with model.num_trees().
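As a quick illustration, the sketch below walks over every tree of the model with get_all_trees. For this CART model the list contains a single tree; a Random Forest would contain many.

# Iterate over all the trees of the model. get_all_trees() returns
# one Tree object per tree in the forest.
print("Number of trees:", model.num_trees())
for tree_idx, tree in enumerate(model.get_all_trees()):
    print(f"Tree #{tree_idx} root type:", type(tree.root).__name__)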
tree = model.get_tree(tree_idx=0)
tree
Tree(root=NonLeaf(
    value=RegressionValue(num_examples=6.0, value=0.3333333432674408, standard_deviation=0.4714045207910317),
    condition=NumericalHigherThanCondition(missing=True, score=0.1111111119389534, attribute=1, threshold=0.5),
    pos_child=NonLeaf(
        value=RegressionValue(num_examples=3.0, value=0.6666666865348816, standard_deviation=0.4714045207910317),
        condition=NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5),
        pos_child=Leaf(value=RegressionValue(num_examples=2.0, value=1.0, standard_deviation=0.0)),
        neg_child=Leaf(value=RegressionValue(num_examples=1.0, value=0.0, standard_deviation=0.0))),
    neg_child=Leaf(value=RegressionValue(num_examples=3.0, value=0.0, standard_deviation=0.0))))
Do you recognize the structure of the tree printed above? You can access parts of the tree. For example, here is the condition on x2:
tree.root.pos_child.condition
NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5)
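Note that a condition refers to its input feature by column index in the dataspec (attribute=2 above). Here is a minimal sketch that maps the index back to the feature name; it assumes the columns of model.data_spec() expose a name field, as in the dataspec proto.

# The condition stores its feature as a dataspec column index.
# Column 2 of this dataset is "x2".
condition = tree.root.pos_child.condition
print(model.data_spec().columns[condition.attribute].name)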
To show the tree in a more readable form, you can use the pretty method.
print(tree.pretty(model.data_spec()))
'x1' >= 0.5 [score=0.11111 missing=True]
    ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True]
    │        ├─(pos)─ value=1 sd=0
    │        └─(neg)─ value=0 sd=0
    └─(neg)─ value=0 sd=0
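Finally, here is a minimal sketch of a recursive traversal of the tree. It assumes the leaf nodes are instances of ydf.tree.Leaf, as the printed structure above suggests, and collects the value stored in each leaf.

def collect_leaf_values(node, values):
    # Leaves carry a RegressionValue; internal nodes carry a condition
    # and two children that we visit in (pos, neg) order.
    if isinstance(node, ydf.tree.Leaf):
        values.append(node.value.value)
    else:
        collect_leaf_values(node.pos_child, values)
        collect_leaf_values(node.neg_child, values)

leaf_values = []
collect_leaf_values(tree.root, leaf_values)
print(leaf_values)  # For this tree: [1.0, 0.0, 0.0]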