pip install ydf -U
import ydf
import pandas as pd
What is Model Interpretation?
Model interpretation (or model understanding) is the process of explaining how a machine learning model works. These methods are crucial for validating and debugging your model, ensuring it aligns with domain knowledge, and building trust in its decisions.
This is different from prediction interpretation, which focuses on explaining a single prediction. To learn how to explain individual predictions, see the prediction understanding tutorial.
YDF offers several powerful tools for model interpretation:
- model.describe(): Provides a high-level summary of the model, including its input features, variable importances, and training logs.
- model.analyze(): Performs a deep analysis of the model's behavior on a dataset, generating rich visualizations like Partial Dependence Plots and SHAP-based variable importances.
- model.print_tree() and model.plot_tree(): Allow you to visualize the structure of individual decision trees, which is most useful for simple models.
Dataset and Model Training
We'll use the "Adult" census dataset to predict whether an individual's income is over $50k.
# Load the training and testing datasets
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Display the first few rows
train_ds.head(3)
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
Now, let's train a Gradient Boosted Trees model. This is a powerful model composed of many trees, making it a good candidate for advanced interpretation techniques.
# Train a Gradient Boosted Trees model
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)
Train model on 22792 examples
Model trained in 0:00:01.874314
High-Level Model Summary with describe()
The model.describe() method is your first stop for understanding a model. It provides a rich, interactive report.
Key sections of the report include:
- Input Features: A list of all the features the model was trained on.
- Variable Importances: A set of scores that rank features by their importance to the model's predictions. There are several types of importance scores, each giving a different perspective on feature utility. For example, NUM_NODES shows how often a feature was used in splits across all trees.
  Note: model.describe() only shows variable importances that can be computed without a test dataset. More advanced variable importances, such as SHAP values, are available in the model analysis report (see below).
- Training Logs: Detailed logs from the training process, which are useful for debugging.
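To make the structural importances concrete, here is a minimal sketch in plain Python. It is not YDF's implementation: it assumes a hypothetical toy forest stored as nested dictionaries (the `toy_forest` structure and all helper names are made up for illustration) and counts how many split nodes use each feature, which is the idea behind NUM_NODES. A related count, how often a feature is used at the root of a tree, is shown alongside it.

```python
from collections import Counter

# Hypothetical toy forest (not a YDF data structure): each internal node is
# a dict with a "feature" key and "pos"/"neg" children; leaves are numbers.
toy_forest = [
    {"feature": "relationship",
     "pos": {"feature": "education", "pos": 0.27, "neg": 0.05},
     "neg": {"feature": "capital_gain", "pos": 0.40, "neg": -0.10}},
    {"feature": "capital_gain",
     "pos": 0.30,
     "neg": {"feature": "age", "pos": 0.02, "neg": -0.08}},
]


def count_split_features(node, counts):
    """Recursively count the feature used at every internal node."""
    if not isinstance(node, dict):  # Leaf: nothing to count.
        return
    counts[node["feature"]] += 1
    count_split_features(node["pos"], counts)
    count_split_features(node["neg"], counts)


def num_nodes_importance(forest):
    """Number of split nodes per feature, summed over all trees."""
    counts = Counter()
    for tree in forest:
        count_split_features(tree, counts)
    return counts


def num_as_root_importance(forest):
    """Number of trees whose root node splits on each feature."""
    return Counter(tree["feature"] for tree in forest)


num_nodes = num_nodes_importance(toy_forest)
num_as_root = num_as_root_importance(toy_forest)
print("Split nodes per feature:", num_nodes.most_common())
print("Root splits per feature:", num_as_root.most_common())
```

Counts like these only describe how the trees are built. They say nothing about how much each split changes the predictions, which is one reason several importance types are reported side by side.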
model.describe()
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Trained with Feature Selection : No
Model size : 2154 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%) dtype:DTYPE_BYTES
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%) dtype:DTYPE_BYTES
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%) dtype:DTYPE_BYTES
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%) dtype:DTYPE_BYTES
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%) dtype:DTYPE_BYTES
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%) dtype:DTYPE_BYTES
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%) dtype:DTYPE_BYTES
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%) dtype:DTYPE_BYTES
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%) dtype:DTYPE_BYTES

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661 dtype:DTYPE_INT64
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423 dtype:DTYPE_INT64
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427 dtype:DTYPE_INT64
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48 dtype:DTYPE_INT64
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01 dtype:DTYPE_INT64
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249 dtype:DTYPE_INT64

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.576143
Accuracy: 0.868526 CI95[W][0 1]
ErrorRate: : 0.131474
Confusion Table:
truth\prediction
<=50K >50K
<=50K 1557 107
>50K 190 405
Total: 2259
Variable importances measure the importance of an input feature for a model.
INV_MEAN_MIN_DEPTH:
1. "age" 0.226664 ################
2. "occupation" 0.219688 #############
3. "education" 0.218052 #############
4. "capital_gain" 0.214876 ############
5. "marital_status" 0.212828 ###########
6. "relationship" 0.205975 #########
7. "fnlwgt" 0.203878 ########
8. "hours_per_week" 0.203735 ########
9. "capital_loss" 0.196549 ######
10. "native_country" 0.190548 ####
11. "workclass" 0.187810 ###
12. "education_num" 0.181204 #
13. "race" 0.180495
14. "sex" 0.177632
NUM_AS_ROOT:
1. "age" 26.000000 ################
2. "capital_gain" 26.000000 ################
3. "marital_status" 20.000000 ############
4. "relationship" 17.000000 ##########
5. "education" 15.000000 ########
6. "capital_loss" 14.000000 ########
7. "hours_per_week" 14.000000 ########
8. "fnlwgt" 10.000000 #####
9. "race" 9.000000 #####
10. "education_num" 4.000000 #
11. "sex" 4.000000 #
12. "occupation" 2.000000
13. "workclass" 1.000000
14. "native_country" 1.000000
NUM_NODES:
1. "occupation" 722.000000 ################
2. "fnlwgt" 515.000000 ###########
3. "age" 483.000000 ##########
4. "education" 459.000000 #########
5. "hours_per_week" 339.000000 #######
6. "capital_gain" 325.000000 ######
7. "native_country" 306.000000 ######
8. "capital_loss" 297.000000 ######
9. "relationship" 262.000000 #####
10. "workclass" 245.000000 #####
11. "marital_status" 210.000000 ####
12. "education_num" 88.000000 #
13. "sex" 41.000000
14. "race" 21.000000
SUM_SCORE:
1. "relationship" 3018.761866 ################
2. "capital_gain" 2065.521668 ##########
3. "education" 1241.764059 ######
4. "marital_status" 1107.545372 #####
5. "occupation" 1094.359168 #####
6. "education_num" 699.517705 ###
7. "capital_loss" 584.055066 ###
8. "age" 582.292563 ###
9. "hours_per_week" 366.856509 #
10. "native_country" 263.872689 #
11. "fnlwgt" 216.537764 #
12. "workclass" 196.221850 #
13. "sex" 46.986269
14. "race" 5.428727
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Below is the first tree of the model. The model contains 163 trees, which jointly make the prediction. Other trees can be printed with `model.print_tree(tree_idx)` or plotted with `model.plot_tree(tree_idx)`
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
├─(pos)─ "education" is in [BITMAP] {Bachelors, Masters, Prof-school, Doctorate} [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
| ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
| | ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
| | | ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
| | | | ├─(pos)─ pred:0.309737
| | | | └─(neg)─ pred:0.418684
| | | └─(neg)─ pred:0.309737
| | └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
| | ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
| | | ├─(pos)─ pred:0.349312
| | | └─(neg)─ pred:0.417359
| | └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
| | ├─(pos)─ pred:0.253437
| | └─(neg)─ pred:0.11557
| └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
| ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
| | ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
| | | ├─(pos)─ pred:0.397934
| | | └─(neg)─ pred:0.205614
| | └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
| | ├─(pos)─ pred:0.419984
| | └─(neg)─ pred:0.419984
| └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
| ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
| | ├─(pos)─ pred:0.132992
| | └─(neg)─ pred:0.00826457
| └─(neg)─ "education_num">=8.5 [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
| ├─(pos)─ pred:0.00969668
| └─(neg)─ pred:-0.0813718
└─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
| ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
| | ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
| | | ├─(pos)─ pred:0.392422
| | | └─(neg)─ pred:0.419984
| | └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
| | ├─(pos)─ pred:0.419984
| | └─(neg)─ pred:-0.0210046
| └─(neg)─ pred:0.0892425
└─(neg)─ "education_num">=12.5 [s:0.00229611 n:11121 np:2199 miss:0] ; pred:-0.10399
├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
| ├─(pos)─ "education_num">=14.5 [s:0.0110157 n:1263 np:125 miss:0] ; pred:-0.0103552
| | ├─(pos)─ pred:0.16421
| | └─(neg)─ pred:-0.0295298
| └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
| ├─(pos)─ pred:0.19949
| └─(neg)─ pred:-0.106976
└─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
| ├─(pos)─ pred:-0.0328167
| └─(neg)─ pred:0.292776
└─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
├─(pos)─ pred:-0.0927111
└─(neg)─ pred:-0.123347
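The accuracy and error rate in the evaluation part of the report above follow directly from its confusion table. The short check below, in plain Python with the counts copied by hand from that table, reproduces both numbers. The variable names are illustrative only.

```python
# Confusion table from the report: outer keys are the true label,
# inner keys are the predicted label.
confusion = {
    "<=50K": {"<=50K": 1557, ">50K": 107},
    ">50K": {"<=50K": 190, ">50K": 405},
}

# Total number of evaluated examples.
total = sum(sum(row.values()) for row in confusion.values())

# Correct predictions sit on the diagonal (truth == prediction).
correct = sum(confusion[label][label] for label in confusion)

accuracy = correct / total
error_rate = 1 - accuracy

print(f"Total: {total}")
print(f"Accuracy: {accuracy:.6f}")
print(f"ErrorRate: {error_rate:.6f}")
```

The same table also shows that most errors come from examples whose true label is >50K being predicted as <=50K (190 cases), which a single accuracy number hides.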
In-Depth Analysis with analyze()
While describe() gives a static summary, model.analyze() computes how the model behaves on a specific dataset. This is essential for understanding the relationships the model has learned.
The analysis generates Partial Dependence Plots (PDPs), which show how the model's average prediction changes as a single feature is varied, averaging over the observed values of all other features. It also calculates additional variable importances, including those based on SHAP.
Note: Model analysis can be computationally expensive on large datasets. Use the sampling parameter to run a faster analysis on a random subset of your data.
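The idea behind a partial dependence curve can be sketched in a few lines of plain Python. This is an illustration of the general technique, not YDF's code: `toy_predict` is a hypothetical stand-in for a trained model, and the examples and grid values are made up.

```python
def partial_dependence(predict, examples, feature, grid):
    """Average prediction when `feature` is forced to each grid value."""
    curve = []
    for value in grid:
        # Overwrite the feature in every example; the other features keep
        # their observed values.
        modified = [{**example, feature: value} for example in examples]
        predictions = [predict(example) for example in modified]
        curve.append(sum(predictions) / len(predictions))
    return curve


# Hypothetical stand-in for a trained model: any function that maps an
# example (a dict of feature values) to a score would work here.
def toy_predict(example):
    score = 0.1
    if example["age"] >= 30:
        score += 0.2
    if example["hours_per_week"] >= 40:
        score += 0.3
    return score


examples = [
    {"age": 25, "hours_per_week": 20},
    {"age": 45, "hours_per_week": 40},
    {"age": 60, "hours_per_week": 50},
    {"age": 35, "hours_per_week": 30},
]

pdp_age = partial_dependence(toy_predict, examples, "age", grid=[20, 30, 40])
print(pdp_age)
```

Each grid value requires one prediction per example, so the cost grows with the dataset size. This is why running the analysis on fewer examples is faster.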
# Analyze the model's behavior on the test dataset
# We use a 10% sample for speed.
analysis = model.analyze(test_ds, sampling=0.1)
# The analysis is interactive in a notebook environment.
# You can also save it to a file:
# analysis.to_file("analysis.html")
Interpreting Simpler Models: The Decision Tree
Powerful models like Random Forests and Gradient Boosted Trees are ensembles of many (often hundreds of) trees.
Warning: It is misleading to interpret a multi-tree model by looking at just one of its trees. Each tree is only a small part of the ensemble, and in GBTs, each tree predicts the error of the previous ones, not the final outcome.
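A small numeric sketch illustrates the warning. It assumes a binary classifier whose trees add corrections to a running score that is then passed through a sigmoid. The initial score and the per-tree outputs are made-up values for a single example, not taken from the model trained above, and the code is not YDF's internal logic.

```python
import math

# Hypothetical values for one example: a starting score plus the output
# of each tree, where every tree adds a small correction.
initial_score = -1.15
tree_outputs = [0.42, 0.31, -0.05, 0.18, 0.09]


def sigmoid(x):
    """Map a raw score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def ensemble_probability(initial, outputs):
    """Sum all tree outputs on top of the initial score, then squash."""
    return sigmoid(initial + sum(outputs))


# Probability if only the first tree were used, versus the whole ensemble.
first_tree_only = sigmoid(initial_score + tree_outputs[0])
full_model = ensemble_probability(initial_score, tree_outputs)

print(f"Using only the first tree: {first_tree_only:.3f}")
print(f"Using all trees:           {full_model:.3f}")
```

The two probabilities differ noticeably even with five trees. With hundreds of trees, the output of any single tree says very little about the final prediction.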
To gain intuition about the fundamental structure of your data, it's often helpful to train a simpler, "weaker" model. A single Decision Tree is perfect for this. It will have lower performance but is fully transparent.
Let's train a shallow Decision Tree and visualize it.
# Train a single, shallow decision tree for interpretability
weak_model = ydf.DecisionTreeLearner(label="income", max_depth=4).train(train_ds)
# Print the class labels for context, e.g. ['<=50K', '>50K']
print("Class labels:", weak_model.label_classes())
# You can print the tree as text
weak_model.print_tree()
Train model on 22792 examples
Model trained in 0:00:00.718672
Class labels: ['<=50K', '>50K']
'relationship' in ['Husband', 'Wife'] [score=0.10683 missing=False]
    ├─(pos)─ 'education_num' >= 12.5 [score=0.070882 missing=False]
    │        ├─(pos)─ value=[0.2672196177425171, 0.7327803822574829]
    │        └─(neg)─ 'capital_gain' >= 5095.5 [score=0.047968 missing=False]
    │                 ├─(pos)─ value=[0.026402640264026403, 0.9735973597359736]
    │                 └─(neg)─ value=[0.7032752159035359, 0.2967247840964641]
    └─(neg)─ 'capital_gain' >= 7073.5 [score=0.04532 missing=False]
             ├─(pos)─ value=[0.04020100502512563, 0.9597989949748744]
             └─(neg)─ value=[0.950544015825915, 0.04945598417408507]
For a classification model, the value at each leaf shows the predicted probability for each class. For example, value=[0.8, 0.2] means the model predicts the first class with 80% probability and the second with 20%.
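To see how a leaf value is reached, the tree printed above can be transcribed by hand into a small Python structure and walked for a single example. This is only a reading aid built from the printout, not a YDF API: the `TREE` dictionary and `leaf_value` helper are hypothetical, and missing-value handling is ignored.

```python
# The shallow tree from the printout, transcribed by hand into nested dicts.
# Leaf values are [P(<=50K), P(>50K)], in the order of the class labels.
TREE = {
    "test": lambda x: x["relationship"] in ("Husband", "Wife"),
    "pos": {
        "test": lambda x: x["education_num"] >= 12.5,
        "pos": [0.2672196177425171, 0.7327803822574829],
        "neg": {
            "test": lambda x: x["capital_gain"] >= 5095.5,
            "pos": [0.026402640264026403, 0.9735973597359736],
            "neg": [0.7032752159035359, 0.2967247840964641],
        },
    },
    "neg": {
        "test": lambda x: x["capital_gain"] >= 7073.5,
        "pos": [0.04020100502512563, 0.9597989949748744],
        "neg": [0.950544015825915, 0.04945598417408507],
    },
}


def leaf_value(node, example):
    """Follow the (pos)/(neg) branches until a leaf is reached."""
    while isinstance(node, dict):
        node = node["pos"] if node["test"](example) else node["neg"]
    return node


# A made-up example: a husband with 13 years of education, no capital gain.
example = {"relationship": "Husband", "education_num": 13, "capital_gain": 0}
p_low, p_high = leaf_value(TREE, example)
print(f"P(<=50K)={p_low:.3f}, P(>50K)={p_high:.3f}")
```

Only the features tested along the path matter for a given example, which is what makes a single shallow tree easy to audit by hand.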
A text printout is useful, but a plot is often clearer. plot_tree() generates a visual representation where you can trace the decision paths.
# Or plot the tree for a visual representation
weak_model.plot_tree()