🧭 Getting started
Decision Forests (DFs) are a family of machine learning algorithms used for classification, regression, ranking, uplift modeling, and anomaly detection. As the name implies, DFs are constructed from a collection of decision trees. The two most popular DF training algorithms today are Random Forests and Gradient Boosted Decision Trees.
Yggdrasil Decision Forests (YDF) is a comprehensive library for training, evaluating, interpreting, and serving these models. YDF is available in Python and C++, and as a command-line interface (CLI). It's also integrated into TensorFlow as TensorFlow Decision Forests. This notebook walks you through the Python API, which is the recommended way to get started with YDF.
For the complete API Reference and more tutorials, check out the YDF website.
Install YDF
```shell
pip install ydf -U
```
Import libraries
```python
import ydf  # Yggdrasil Decision Forests
import pandas as pd  # Used for loading and manipulating small datasets
```
Download and load dataset
We'll use the classic "Adult" dataset for this tutorial. The task is binary classification: predict whether an individual's income is >50K or <=50K based on other numerical and categorical features. This dataset also contains missing values, which YDF handles automatically.
```python
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"

# Download and load the dataset into Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")

# Display the first 5 rows of the training data
train_ds.head(5)
```
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
Train a model
Let's train a Gradient Boosted Trees model using the default hyper-parameters.
```python
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)
```
Train model on 22792 examples
Model trained in 0:00:01.791216
Key points
- YDF distinguishes between learning algorithms (called learners, like GradientBoostedTreesLearner) and trained models. You'll see the benefits of this distinction in more advanced examples.
- The only required parameter for a learner is the label. All other hyper-parameters have sensible defaults.
- Since we didn't specify the input features, all columns except the label are automatically used as inputs. YDF detects feature types (e.g., numerical, categorical) and handles them appropriately, including features with missing values.
- By default, learners train a classification model. You can specify other tasks, such as regression or ranking, with the task parameter (e.g., task=ydf.Task.REGRESSION).
- Training logs can be viewed live by setting verbose=2 in the learner. After training, you can access them with model.describe().
- No validation dataset was provided. In this case, learners like GradientBoostedTreesLearner automatically set aside a portion of the training data for validation (in this run, 2259 of the 22792 examples). Other learners, like RandomForestLearner, don't require a validation set and use all the data for training.
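The automatic hold-out described in the last point can be sketched in plain Python. This is not YDF's implementation: the 10% ratio, the per-example random assignment, and the seed are assumptions, chosen because the describe() output below is consistent with roughly that split (2259 of 22792 examples are used for validation).

```python
import random


def split_train_validation(num_examples, validation_ratio=0.1, seed=1234):
    """Randomly assigns each example index to the training or validation set.

    Illustrative sketch only: the ratio and seed are assumptions, not YDF's
    internal values.
    """
    rng = random.Random(seed)
    train_idx, valid_idx = [], []
    for i in range(num_examples):
        # Each example independently lands in validation with probability ratio.
        (valid_idx if rng.random() < validation_ratio else train_idx).append(i)
    return train_idx, valid_idx


train_idx, valid_idx = split_train_validation(22792)
print(len(train_idx), len(valid_idx))
```

Because each example is assigned independently, the validation set size fluctuates around 10% rather than being exact, which matches the 2259 (about 9.9%) seen in this run.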
Inspect the model
The model.describe() method provides an overview of your model, including:
- Model: The model's task, input features, and size.
- Dataspec: Statistics about each input feature.
- Training: The training and validation loss and performance metrics.
- Variable Importance: A ranking of the features that are most influential for the model.
- Structure: A plot of the first tree of the model.
- Tuning: Logs from hyper-parameter tuning (if enabled).
```python
model.describe()
```
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Model size : 2206 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
CATEGORICAL: 9 (60%)
NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%) dtype:DTYPE_BYTES
2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%) dtype:DTYPE_BYTES
4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%) dtype:DTYPE_BYTES
6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%) dtype:DTYPE_BYTES
7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%) dtype:DTYPE_BYTES
8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%) dtype:DTYPE_BYTES
9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%) dtype:DTYPE_BYTES
10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%) dtype:DTYPE_BYTES
14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%) dtype:DTYPE_BYTES

NUMERICAL: 6 (40%)
1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661 dtype:DTYPE_INT64
3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423 dtype:DTYPE_INT64
5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427 dtype:DTYPE_INT64
11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48 dtype:DTYPE_INT64
12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01 dtype:DTYPE_INT64
13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249 dtype:DTYPE_INT64

Terminology:
nas: Number of non-available (i.e. missing) values.
ood: Out of dictionary.
manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
tokenized: The attribute value is obtained through tokenization.
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
vocab-size: Number of unique values.
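The dataspec percentages can be reproduced by hand. For "workclass", 1257 missing values out of 22792 records is 5.51509%, while "Private" accounts for 15879 of the 22792 - 1257 = 21535 non-missing values, i.e. 73.7358%. The sketch below computes the same kind of summary for a toy column. Two details are inferred from the output above rather than documented: the most-frequent share excludes missing values, and vocab-size appears to count one extra <OOD> entry ("sex" has two values but vocab-size:3). The sketch also ignores the folding of rare values into <OOD> that the num-oods fields seem to reflect.

```python
from collections import Counter


def categorical_spec(values):
    """Summarizes a categorical column like the CATEGORICAL dataspec entries.

    `None` marks a missing value. Illustrative sketch, not YDF code.
    """
    present = [v for v in values if v is not None]
    num_nas = len(values) - len(present)
    counts = Counter(present)
    most_frequent, freq = counts.most_common(1)[0]
    return {
        "num_nas": num_nas,
        # Missing values are reported as a share of all records.
        "nas_percent": 100.0 * num_nas / len(values),
        # Assumption: one extra dictionary slot is reserved for <OOD>.
        "vocab_size": len(counts) + 1,
        "most_frequent": most_frequent,
        # The most-frequent share is relative to the non-missing values.
        "most_frequent_percent": 100.0 * freq / len(present),
    }


spec = categorical_spec(["Private", None, "Private", "State-gov"])
print(spec)
```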
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.576143
Accuracy: 0.868526 CI95[W][0 1]
ErrorRate: : 0.131474
Confusion Table:
truth\prediction
<=50K >50K
<=50K 1557 107
>50K 190 405
Total: 2259
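The validation accuracy and error rate reported above follow directly from the confusion table: correct predictions sit on the diagonal. This short check recomputes them from the printed counts.

```python
# Validation confusion table from the describe() output above
# (rows: truth, columns: prediction, classes "<=50K" then ">50K").
confusion = [[1557, 107],
             [190, 405]]

total = sum(sum(row) for row in confusion)
correct = confusion[0][0] + confusion[1][1]
accuracy = correct / total
print(total, round(accuracy, 6), round(1 - accuracy, 6))  # 2259 0.868526 0.131474
```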
Variable importances measure the importance of an input feature for a model.
INV_MEAN_MIN_DEPTH:
1. "age" 0.226664 ################
2. "occupation" 0.219688 #############
3. "education" 0.218052 #############
4. "capital_gain" 0.214876 ############
5. "marital_status" 0.212828 ###########
6. "relationship" 0.205975 #########
7. "fnlwgt" 0.203878 ########
8. "hours_per_week" 0.203735 ########
9. "capital_loss" 0.196549 ######
10. "native_country" 0.190548 ####
11. "workclass" 0.187810 ###
12. "education_num" 0.181204 #
13. "race" 0.180495
14. "sex" 0.177632
NUM_AS_ROOT:
1. "age" 26.000000 ################
2. "capital_gain" 26.000000 ################
3. "marital_status" 20.000000 ############
4. "relationship" 17.000000 ##########
5. "education" 15.000000 ########
6. "capital_loss" 14.000000 ########
7. "hours_per_week" 14.000000 ########
8. "fnlwgt" 10.000000 #####
9. "race" 9.000000 #####
10. "education_num" 4.000000 #
11. "sex" 4.000000 #
12. "occupation" 2.000000
13. "workclass" 1.000000
14. "native_country" 1.000000
NUM_NODES:
1. "occupation" 722.000000 ################
2. "fnlwgt" 515.000000 ###########
3. "age" 483.000000 ##########
4. "education" 459.000000 #########
5. "hours_per_week" 339.000000 #######
6. "capital_gain" 325.000000 ######
7. "native_country" 306.000000 ######
8. "capital_loss" 297.000000 ######
9. "relationship" 262.000000 #####
10. "workclass" 245.000000 #####
11. "marital_status" 210.000000 ####
12. "education_num" 88.000000 #
13. "sex" 41.000000
14. "race" 21.000000
SUM_SCORE:
1. "relationship" 3018.761866 ################
2. "capital_gain" 2065.521668 ##########
3. "education" 1241.764059 ######
4. "marital_status" 1107.545372 #####
5. "occupation" 1094.359168 #####
6. "education_num" 699.517705 ###
7. "capital_loss" 584.055066 ###
8. "age" 582.292563 ###
9. "hours_per_week" 366.856509 #
10. "native_country" 263.872689 #
11. "fnlwgt" 216.537764 #
12. "workclass" 196.221850 #
13. "sex" 46.986269
14. "race" 5.428727
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
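The importances that model.analyze() computes on a test dataset include permutation-style measures: how much a metric such as accuracy drops when one feature's values are shuffled across examples, breaking that feature's link to the label. The sketch below shows the idea on a hypothetical toy model with made-up rows. It is not YDF's implementation, and a fixed rotation replaces the random shuffle so the result is deterministic.

```python
def accuracy(model, rows, labels):
    """Fraction of rows where the model's prediction equals the label."""
    return sum(model(r) == y for r, y in zip(rows, labels)) / len(labels)


def permutation_importance(model, rows, labels, feature):
    """Accuracy drop when one feature's column is permuted across examples.

    A rotation by one position stands in for a random shuffle.
    """
    baseline = accuracy(model, rows, labels)
    column = [r[feature] for r in rows]
    rotated = column[1:] + column[:1]
    permuted_rows = [dict(r, **{feature: v}) for r, v in zip(rows, rotated)]
    return baseline - accuracy(model, permuted_rows, labels)


def toy_model(row):
    """Hypothetical model: only looks at "x" and ignores "noise"."""
    return int(row["x"] > 0.5)


rows = [{"x": 0.1, "noise": 5}, {"x": 0.9, "noise": 3},
        {"x": 0.2, "noise": 1}, {"x": 0.8, "noise": 7}]
labels = [0, 1, 0, 1]
importances = {f: permutation_importance(toy_model, rows, labels, f)
               for f in ["x", "noise"]}
print(importances)  # {'x': 1.0, 'noise': 0.0}
```

A feature the model never uses gets an importance of exactly zero, while scrambling a feature the model relies on destroys its accuracy. On real data, small negative values (as in some test-set importances) can appear when shuffling happens to help slightly by chance.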
Only printing the first tree.
```text
Tree #0:
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
├─(pos)─ "education" is in [BITMAP] {Bachelors, Masters, Prof-school, Doctorate} [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
|   ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
|   |   ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
|   |   |   ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
|   |   |   |   ├─(pos)─ pred:0.309737
|   |   |   |   └─(neg)─ pred:0.418684
|   |   |   └─(neg)─ pred:0.309737
|   |   └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
|   |       ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
|   |       |   ├─(pos)─ pred:0.349312
|   |       |   └─(neg)─ pred:0.417359
|   |       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
|   |           ├─(pos)─ pred:0.253437
|   |           └─(neg)─ pred:0.11557
|   └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
|       ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
|       |   ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
|       |   |   ├─(pos)─ pred:0.397934
|       |   |   └─(neg)─ pred:0.205614
|       |   └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
|       |       ├─(pos)─ pred:0.419984
|       |       └─(neg)─ pred:0.419984
|       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
|           ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
|           |   ├─(pos)─ pred:0.132992
|           |   └─(neg)─ pred:0.00826457
|           └─(neg)─ "education_num">=8.5 [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
|               ├─(pos)─ pred:0.00969668
|               └─(neg)─ pred:-0.0813718
└─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
    ├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
    |   ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
    |   |   ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
    |   |   |   ├─(pos)─ pred:0.392422
    |   |   |   └─(neg)─ pred:0.419984
    |   |   └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
    |   |       ├─(pos)─ pred:0.419984
    |   |       └─(neg)─ pred:-0.0210046
    |   └─(neg)─ pred:0.0892425
    └─(neg)─ "education_num">=12.5 [s:0.00229611 n:11121 np:2199 miss:0] ; pred:-0.10399
        ├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
        |   ├─(pos)─ "education_num">=14.5 [s:0.0110157 n:1263 np:125 miss:0] ; pred:-0.0103552
        |   |   ├─(pos)─ pred:0.16421
        |   |   └─(neg)─ pred:-0.0295298
        |   └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
        |       ├─(pos)─ pred:0.19949
        |       └─(neg)─ pred:-0.106976
        └─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
            ├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
            |   ├─(pos)─ pred:-0.0328167
            |   └─(neg)─ pred:0.292776
            └─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
                ├─(pos)─ pred:-0.0927111
                └─(neg)─ pred:-0.123347
```
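To make the tree dump concrete, here is the negative branch of Tree #0 (the part reached when "relationship" is not Husband, Wife, or out-of-dictionary) hand-translated into Python and applied to the example passed to model.predict() in the next section. It is a sketch under simplifying assumptions: missing values are not handled (the miss: routing flags are ignored), unseen relationship values are not routed to the positive branch as <OOD> would be, and the positive branch is omitted. It covers only one tree: a gradient boosted model adds up the leaf values of all its trees (plus an initial bias) and, for this binary task, passes the sum through a sigmoid.

```python
def tree0_negative_branch(ex):
    """Leaf value of Tree #0 for examples that fail the root condition."""
    if ex["capital_gain"] >= 7073.5:
        if ex["age"] >= 21.5:
            if ex["capital_gain"] >= 7565.5:
                return 0.392422 if ex["capital_gain"] >= 30961.5 else 0.419984
            high_degrees = {"Bachelors", "Masters", "Assoc-voc", "Prof-school"}
            return 0.419984 if ex["education"] in high_degrees else -0.0210046
        return 0.0892425
    if ex["education_num"] >= 12.5:
        if ex["age"] >= 31.5:
            return 0.16421 if ex["education_num"] >= 14.5 else -0.0295298
        return 0.19949 if ex["capital_loss"] >= 1977 else -0.106976
    if ex["capital_loss"] >= 2218.5:
        return -0.0328167 if ex["fnlwgt"] >= 125450 else 0.292776
    return -0.0927111 if ex["hours_per_week"] >= 40.5 else -0.123347


def tree0(ex):
    """Root split of Tree #0; the positive branch is left out of this sketch."""
    if ex["relationship"] in {"Husband", "Wife"}:
        raise NotImplementedError("positive branch omitted in this sketch")
    return tree0_negative_branch(ex)


example = {"age": 39, "education": "Bachelors", "education_num": 13,
           "relationship": "Not-in-family", "capital_gain": 2174,
           "capital_loss": 0, "fnlwgt": 77516, "hours_per_week": 40}
print(tree0(example))  # -0.0295298
```

The example follows capital_gain < 7073.5, education_num >= 12.5, age >= 31.5, and education_num < 14.5, ending at the leaf with value -0.0295298.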
Make predictions
To get predictions, use the model.predict() method. It returns the predictions as a NumPy array. For this binary classification model, each value is the predicted probability of the positive class, ">50K".
```python
model.predict(test_ds)
```
array([0.01860435, 0.36130956, 0.83858865, ..., 0.03087652, 0.08280362,
       0.00970956], shape=(9769,), dtype=float32)
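Turning these probabilities into class names requires a threshold. The sketch below assumes the values are the probability of ">50K" and uses a 0.5 cutoff, a common default rather than the only sensible choice; the inputs are the first three values printed above.

```python
probabilities = [0.01860435, 0.36130956, 0.83858865]  # first three values above
THRESHOLD = 0.5  # assumed cutoff; tune it for your precision/recall needs
POSITIVE, NEGATIVE = ">50K", "<=50K"

labels = [POSITIVE if p >= THRESHOLD else NEGATIVE for p in probabilities]
print(labels)  # ['<=50K', '<=50K', '>50K']
```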
Methods like train() and predict() are flexible and accept data in various formats, including Pandas DataFrames, dictionaries of lists or NumPy arrays, TensorFlow Datasets, or even file paths.
```python
# Prediction with a dictionary
model.predict({
    'age': [39],
    'workclass': ['State-gov'],
    'fnlwgt': [77516],
    'education': ['Bachelors'],
    'education_num': [13],
    'marital_status': ['Never-married'],
    'occupation': ['Adm-clerical'],
    'relationship': ['Not-in-family'],
    'race': ['White'],
    'sex': ['Male'],
    'capital_gain': [2174],
    'capital_loss': [0],
    'hours_per_week': [40],
    'native_country': ['United-States'],
    'income': ['<=50K'],
})
```
array([0.01860435], dtype=float32)
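The dictionary above is column-oriented: each feature maps to a list with one value per example. If your examples arrive one record per row (for instance, from a JSON API), a small helper such as the hypothetical rows_to_columns below can build that layout first. It is a sketch, not part of YDF.

```python
def rows_to_columns(rows):
    """Converts a list of per-example dicts into a dict of per-feature lists.

    Hypothetical helper for illustration; every row must have the same keys.
    """
    if not rows:
        return {}
    names = list(rows[0])
    for row in rows:
        if set(row) != set(names):
            raise ValueError("all rows must have the same feature names")
    return {name: [row[name] for row in rows] for name in names}


rows = [
    {"age": 39, "education": "Bachelors", "hours_per_week": 40},
    {"age": 50, "education": "HS-grad", "hours_per_week": 13},
]
columns = rows_to_columns(rows)
print(columns)
# {'age': [39, 50], 'education': ['Bachelors', 'HS-grad'], 'hours_per_week': [40, 13]}
```

With all of the model's input features present in each row, the result has the same shape as the dictionary passed to model.predict() above.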
Evaluate model
While the internal validation set gives us a good idea of the model's quality, we should also evaluate its performance on the unseen test dataset.
```python
evaluation = model.evaluate(test_ds)

# Query individual evaluation metrics
print(f"Test accuracy: {evaluation.accuracy}")

# Show the full evaluation report
print("Full evaluation report:")
evaluation
```
Test accuracy: 0.8737844201044119
Full evaluation report:
Evaluation of classification models
- Accuracy
  - The simplest metric. It's the percentage of predictions that are correct (matching the ground truth).
  - Example: If a model correctly identifies 90 out of 100 images as cat or dog, the accuracy is 90%.
- Confusion Matrix
  - A table that shows the counts of:
    - True Positives (TP): Model correctly predicted positive.
    - True Negatives (TN): Model correctly predicted negative.
    - False Positives (FP): Model incorrectly predicted positive (a "false alarm").
    - False Negatives (FN): Model incorrectly predicted negative (a "miss").
- Threshold
  - YDF classification models predict a probability for each class. A threshold determines the cutoff for classifying something as positive or negative.
  - Example: If the threshold is 0.5, any prediction above 0.5 might be classified as "spam," and anything below as "not spam."
- ROC Curve (Receiver Operating Characteristic Curve)
  - A graph that plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at various thresholds.
    - TPR (Sensitivity or Recall): TP / (TP + FN). How many of the actual positives did the model catch?
    - FPR: FP / (FP + TN). How many of the actual negatives were incorrectly classified as positive?
  - Interpretation: A good model has an ROC curve that hugs the top-left corner (high TPR, low FPR).
- AUC (Area Under the ROC Curve)
  - A single number that summarizes the overall performance shown by the ROC curve. Because it does not depend on a single threshold, the AUC is a more stable metric than the accuracy. Multi-class classification models evaluate one class against all other classes.
  - Interpretation: Ranges from 0 to 1. A perfect model has an AUC of 1, while a random model has an AUC of 0.5. Higher is better.
- Precision-Recall Curve
  - A graph that plots Precision against Recall at various thresholds.
    - Precision: TP / (TP + FP). Out of all the predictions the model labeled as positive, how many were actually positive?
    - Recall (same as TPR): TP / (TP + FN). Out of all the actual positive cases, how many did the model correctly identify?
  - Interpretation: A good model has a curve that stays high (both high precision and high recall). It is especially useful when dealing with imbalanced datasets (e.g., when one class is much rarer than the other).
- PR-AUC (Area Under the Precision-Recall Curve)
  - Similar to AUC, but for the Precision-Recall curve. A single number summarizing performance. Multi-class classification models evaluate one class against all other classes. Higher is better.
- Threshold / Accuracy Curve
  - A graph that shows how the model's accuracy changes as you vary the classification threshold.
- Threshold / Volume Curve
  - A graph showing how the number of data points classified as positive changes as you vary the threshold.
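The AUC has an equivalent reading that helps with interpretation: it is the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one (ties count as half). The sketch below computes it that way on made-up scores. YDF's own evaluation may compute the area differently, for example by integrating the ROC curve; for an exactly integrated empirical curve both give the same value.

```python
def pairwise_auc(positive_scores, negative_scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. O(P * N), fine for illustration only.
    """
    correct = 0.0
    for p in positive_scores:
        for n in negative_scores:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(positive_scores) * len(negative_scores))


print(pairwise_auc([0.9, 0.8, 0.4], [0.7, 0.3, 0.2]))  # 8 of 9 pairs, about 0.889
print(pairwise_auc([0.9, 0.8], [0.2, 0.1]))  # 1.0: perfect ranking
print(pairwise_auc([0.5], [0.5]))  # 0.5: no ranking information
```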
| Label \ Pred | <=50K | >50K |
|---|---|---|
| <=50K | 6961 | 451 |
| >50K | 782 | 1575 |
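The headline test accuracy can be reproduced from this confusion matrix, and the same four counts give the other threshold-dependent metrics defined above. Treating ">50K" as the positive class is our assumption.

```python
# Test-set confusion matrix from the table above; ">50K" is treated as positive.
tn, fp = 6961, 451  # truth "<=50K": predicted "<=50K", predicted ">50K"
fn, tp = 782, 1575  # truth ">50K": predicted "<=50K", predicted ">50K"

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)  # also the true positive rate (TPR)
fpr = fp / (fp + tn)  # false positive rate

print(f"accuracy={accuracy:.4f} precision={precision:.4f} "
      f"recall={recall:.4f} fpr={fpr:.4f}")
# accuracy=0.8738 precision=0.7774 recall=0.6682 fpr=0.0608
```

Recall is noticeably lower than precision: at this threshold, the model misses about a third of the high-income individuals while raising few false alarms.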
Analyze model
With model.analyze(ds), you can gain deeper insights into your model's behavior. For instance, Partial Dependence Plots (PDP) show how the model's predictions change as a single feature's value changes. Here, sampling=0.1 runs the analysis on about 10% of the test examples to make it faster.
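A partial dependence curve can be computed by brute force: for each candidate value of one feature, overwrite that feature in every example, run the model, and average the predictions. The sketch below applies this to a hypothetical toy model; analyze() computes the plots for you, and YDF's exact procedure (grid choice, sampling) may differ.

```python
def partial_dependence(predict, rows, feature, grid):
    """Average prediction when `feature` is forced to each grid value."""
    curve = []
    for value in grid:
        # Overwrite the feature in every row, keeping all other features.
        modified = [dict(row, **{feature: value}) for row in rows]
        predictions = [predict(row) for row in modified]
        curve.append(sum(predictions) / len(predictions))
    return curve


def toy_predict(row):
    """Hypothetical model: the score rises with age and with capital_gain."""
    score = 0.01 * (row["age"] - 20) + (0.3 if row["capital_gain"] > 5000 else 0.0)
    return min(max(score, 0.0), 1.0)  # clamp to a valid probability


rows = [{"age": 25, "capital_gain": 0}, {"age": 45, "capital_gain": 7000}]
curve = partial_dependence(toy_predict, rows, "age", [20, 40, 60])
print(curve)  # approximately [0.15, 0.35, 0.55]
```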
```python
model.analyze(test_ds, sampling=0.1)
```
Variable importances measure the importance of an input feature for a model.
1. "capital_gain" 0.050159 ################
2. "occupation" 0.018733 ######
3. "relationship" 0.017300 #####
4. "age" 0.016378 #####
5. "capital_loss" 0.015252 ####
6. "marital_status" 0.014638 ####
7. "education" 0.013103 ####
8. "hours_per_week" 0.010032 ###
9. "workclass" 0.004504 #
10. "education_num" 0.002559
11. "native_country" 0.000614
12. "race" 0.000512
13. "sex" -0.000102
14. "fnlwgt" -0.000307
1. "capital_gain" 0.241216 ################
2. "age" 0.046747 ###
3. "marital_status" 0.045682 ##
4. "capital_loss" 0.044526 ##
5. "relationship" 0.040599 ##
6. "occupation" 0.034399 ##
7. "education" 0.023547 #
8. "hours_per_week" 0.020454 #
9. "education_num" 0.006596
10. "workclass" 0.004845
11. "native_country" 0.003937
12. "fnlwgt" 0.003755
13. "sex" 0.002375
14. "race" 0.000968
1. "capital_gain" 0.057987 ################
2. "age" 0.032421 ########
3. "marital_status" 0.030082 ########
4. "relationship" 0.020362 #####
5. "occupation" 0.016123 ####
6. "capital_loss" 0.015147 ####
7. "hours_per_week" 0.012690 ###
8. "education" 0.011416 ###
9. "education_num" 0.003740
10. "workclass" 0.002466
11. "native_country" 0.001954
12. "sex" 0.001332
13. "fnlwgt" 0.001201
14. "race" 0.000515
1. "capital_gain" 0.240962 ################
2. "age" 0.046745 ###
3. "marital_status" 0.045676 ##
4. "capital_loss" 0.044507 ##
5. "relationship" 0.040542 ##
6. "occupation" 0.034391 ##
7. "education" 0.023545 #
8. "hours_per_week" 0.020452 #
9. "education_num" 0.006594
10. "workclass" 0.004848
11. "native_country" 0.003936
12. "fnlwgt" 0.003755
13. "sex" 0.002374
14. "race" 0.000968
1. "marital_status" 0.725291 ################
2. "age" 0.631228 #############
3. "relationship" 0.565436 ############
4. "capital_gain" 0.467118 #########
5. "occupation" 0.364170 #######
6. "education" 0.316347 ######
7. "hours_per_week" 0.281264 #####
8. "education_num" 0.155824 ##
9. "capital_loss" 0.142157 ##
10. "sex" 0.137268 ##
11. "fnlwgt" 0.092738 #
12. "native_country" 0.091148 #
13. "workclass" 0.080679
14. "race" 0.038975
INV_MEAN_MIN_DEPTH:
1. "age" 0.226664 ################
2. "occupation" 0.219688 #############
3. "education" 0.218052 #############
4. "capital_gain" 0.214876 ############
5. "marital_status" 0.212828 ###########
6. "relationship" 0.205975 #########
7. "fnlwgt" 0.203878 ########
8. "hours_per_week" 0.203735 ########
9. "capital_loss" 0.196549 ######
10. "native_country" 0.190548 ####
11. "workclass" 0.187810 ###
12. "education_num" 0.181204 #
13. "race" 0.180495
14. "sex" 0.177632
NUM_AS_ROOT:
1. "age" 26.000000 ################
2. "capital_gain" 26.000000 ################
3. "marital_status" 20.000000 ############
4. "relationship" 17.000000 ##########
5. "education" 15.000000 ########
6. "capital_loss" 14.000000 ########
7. "hours_per_week" 14.000000 ########
8. "fnlwgt" 10.000000 #####
9. "race" 9.000000 #####
10. "education_num" 4.000000 #
11. "sex" 4.000000 #
12. "occupation" 2.000000
13. "workclass" 1.000000
14. "native_country" 1.000000
NUM_NODES:
1. "occupation" 722.000000 ################
2. "fnlwgt" 515.000000 ###########
3. "age" 483.000000 ##########
4. "education" 459.000000 #########
5. "hours_per_week" 339.000000 #######
6. "capital_gain" 325.000000 ######
7. "native_country" 306.000000 ######
8. "capital_loss" 297.000000 ######
9. "relationship" 262.000000 #####
10. "workclass" 245.000000 #####
11. "marital_status" 210.000000 ####
12. "education_num" 88.000000 #
13. "sex" 41.000000
14. "race" 21.000000
SUM_SCORE:
1. "relationship" 3018.761866 ################
2. "capital_gain" 2065.521668 ##########
3. "education" 1241.764059 ######
4. "marital_status" 1107.545372 #####
5. "occupation" 1094.359168 #####
6. "education_num" 699.517705 ###
7. "capital_loss" 584.055066 ###
8. "age" 582.292563 ###
9. "hours_per_week" 366.856509 #
10. "native_country" 263.872689 #
11. "fnlwgt" 216.537764 #
12. "workclass" 196.221850 #
13. "sex" 46.986269
14. "race" 5.428727
Benchmark model speed
For applications where inference speed is critical, you can use model.benchmark(ds) to measure the model's inference speed.
```python
model.benchmark(test_ds)
```
Single-thread inference time per example: 0.718 us (microseconds)
Details: 4190901 predictions in 0.000 seconds
Multi-thread inference time per example: 0.059 us (microseconds)
Details: 36663057 predictions in 0.000 seconds using 24 threads
* Measured with the C++ serving API. See model.to_cpp().
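Per-example latency is often easier to reason about as throughput. Assuming the reported times are averages over the whole run, the arithmetic below converts them into predictions per second and shows how close the 24-thread run comes to a linear speedup. Only the conversion is added; the inputs are the numbers printed above.

```python
# Per-example latencies reported by model.benchmark() above, in microseconds.
single_thread_us = 0.718
multi_thread_us = 0.059
num_threads = 24

single_thread_qps = 1e6 / single_thread_us  # predictions per second
multi_thread_qps = 1e6 / multi_thread_us
speedup = single_thread_us / multi_thread_us

print(f"{single_thread_qps:,.0f} predictions/s on one thread")
print(f"{multi_thread_qps:,.0f} predictions/s on {num_threads} threads")
print(f"speedup {speedup:.1f}x, parallel efficiency {speedup / num_threads:.0%}")
```

That is roughly 1.4 million predictions per second on one thread and about 17 million on 24 threads, a speedup of about 12x, so parallel efficiency is around 50%.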
The benchmark uses the underlying C++ serving API, so it does not include the overhead added by the Python API: predictions made from Python with model.predict() take longer per example. To run the model (or your own benchmark) directly in C++, model.to_cpp() generates C++ code that loads and runs it:
```python
print(model.to_cpp())
```
```cpp
// Automatically generated code running an Yggdrasil Decision Forests model in
// C++. This code was generated with "model.to_cpp()".
//
// Date of generation: 2025-06-25 11:39:07.368064
// YDF Version: 0.12.0
//
// How to use this code:
//
// 1. Copy this code in a new .h file.
// 2. If you use Bazel/Blaze, use the following dependencies:
//      //third_party/absl/status:statusor
//      //third_party/absl/strings
//      //third_party/yggdrasil_decision_forests/api:serving
// 3. In your existing code, include the .h file. Make predictions as follows:
//   // Load the model (to do only once).
//   namespace ydf = yggdrasil_decision_forests;
//   const auto model = ydf::exported_model_123::Load(<path to model>);
//   // Run the model
//   predictions = model.Predict();
// 4. By default, the "Predict" function takes no inputs and creates fake
//   examples. In practice, you want to add your input data as arguments to
//   "Predict" and call "examples->Set..." functions accordingly.
// 5. (Bonus)
//   Allocate one `examples` and `predictions` vector per thread and reuse them
//   to speed-up the inference.
//
#ifndef YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
#define YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model

#include <memory>
#include <vector>

#include "third_party/absl/status/statusor.h"
#include "third_party/absl/strings/string_view.h"
#include "third_party/yggdrasil_decision_forests/api/serving.h"

namespace yggdrasil_decision_forests {
namespace exported_model_my_model {

struct ServingModel {
  std::vector<float> Predict() const;

  // Compiled model.
  std::unique_ptr<serving_api::FastEngine> engine;

  // Index of the input features of the model.
  //
  // Non-owning pointer. The data is owned by the engine.
  const serving_api::FeaturesDefinition* features;

  // Number of output predictions for each example.
  // Equal to 1 for regression, ranking and binary classification with compact
  // format. Equal to the number of classes for classification.
  int NumPredictionDimension() const {
    return engine->NumPredictionDimension();
  }

  // Indexes of the input features.
  serving_api::NumericalFeatureId feature_age;
  serving_api::CategoricalFeatureId feature_workclass;
  serving_api::NumericalFeatureId feature_fnlwgt;
  serving_api::CategoricalFeatureId feature_education;
  serving_api::NumericalFeatureId feature_education_num;
  serving_api::CategoricalFeatureId feature_marital_status;
  serving_api::CategoricalFeatureId feature_occupation;
  serving_api::CategoricalFeatureId feature_relationship;
  serving_api::CategoricalFeatureId feature_race;
  serving_api::CategoricalFeatureId feature_sex;
  serving_api::NumericalFeatureId feature_capital_gain;
  serving_api::NumericalFeatureId feature_capital_loss;
  serving_api::NumericalFeatureId feature_hours_per_week;
  serving_api::CategoricalFeatureId feature_native_country;
};

// TODO: Pass input feature values to "Predict".
inline std::vector<float> ServingModel::Predict() const {
  // Allocate memory for 2 examples. Alternatively, for speed-sensitive code,
  // an "examples" object can be allocated for each thread and reused. It is
  // okay to allocate more examples than needed.
  const int num_examples = 2;
  auto examples = engine->AllocateExamples(num_examples);

  // Set all the values to be missing. The values may then be overridden by the
  // "Set*" methods. If all the values are set with "Set*" methods,
  // "FillMissing" can be skipped.
  examples->FillMissing(*features);

  // Example #0
  examples->SetNumerical(/*example_idx=*/0, feature_age, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_workclass, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_fnlwgt, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_education, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_education_num, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_marital_status, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_occupation, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_relationship, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_race, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_sex, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_capital_gain, 1.f, *features);
  examples->SetNumerical(/*example_idx=*/0, feature_capital_loss, 1.f, *features);
  examples->SetNumerical(/*example_idx=*/0, feature_hours_per_week, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_native_country, "A", *features);

  // Example #1
  examples->SetNumerical(/*example_idx=*/1, feature_age, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_workclass, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_fnlwgt, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_education, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_education_num, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_marital_status, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_occupation, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_relationship, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_race, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_sex, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_capital_gain, 2.f, *features);
  examples->SetNumerical(/*example_idx=*/1, feature_capital_loss, 2.f, *features);
  examples->SetNumerical(/*example_idx=*/1, feature_hours_per_week, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_native_country, "B", *features);

  // Run the model on the two examples.
  //
  // For speed-sensitive code, reuse the same predictions.
  std::vector<float> predictions;
  engine->Predict(*examples, num_examples, &predictions);
  return predictions;
}

inline absl::StatusOr<ServingModel> Load(absl::string_view path) {
  ServingModel m;

  // Load the model
  ASSIGN_OR_RETURN(auto model, serving_api::LoadModel(path));

  // Compile the model into an inference engine.
  ASSIGN_OR_RETURN(m.engine, model->BuildFastEngine());

  // Index the input features of the model.
  m.features = &m.engine->features();

  // Index the input features.
  ASSIGN_OR_RETURN(m.feature_age, m.features->GetNumericalFeatureId("age"));
  ASSIGN_OR_RETURN(m.feature_workclass, m.features->GetCategoricalFeatureId("workclass"));
  ASSIGN_OR_RETURN(m.feature_fnlwgt, m.features->GetNumericalFeatureId("fnlwgt"));
  ASSIGN_OR_RETURN(m.feature_education, m.features->GetCategoricalFeatureId("education"));
  ASSIGN_OR_RETURN(m.feature_education_num, m.features->GetNumericalFeatureId("education_num"));
  ASSIGN_OR_RETURN(m.feature_marital_status, m.features->GetCategoricalFeatureId("marital_status"));
  ASSIGN_OR_RETURN(m.feature_occupation, m.features->GetCategoricalFeatureId("occupation"));
  ASSIGN_OR_RETURN(m.feature_relationship, m.features->GetCategoricalFeatureId("relationship"));
  ASSIGN_OR_RETURN(m.feature_race, m.features->GetCategoricalFeatureId("race"));
  ASSIGN_OR_RETURN(m.feature_sex, m.features->GetCategoricalFeatureId("sex"));
  ASSIGN_OR_RETURN(m.feature_capital_gain, m.features->GetNumericalFeatureId("capital_gain"));
  ASSIGN_OR_RETURN(m.feature_capital_loss, m.features->GetNumericalFeatureId("capital_loss"));
  ASSIGN_OR_RETURN(m.feature_hours_per_week, m.features->GetNumericalFeatureId("hours_per_week"));
  ASSIGN_OR_RETURN(m.feature_native_country, m.features->GetCategoricalFeatureId("native_country"));

  return m;
}

} // namespace exported_model_my_model
} // namespace yggdrasil_decision_forests

#endif // YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
```
Save model
Finally, let's save our trained model so we can use it later without retraining.
```python
model.save("/tmp/my_model")
```
Loading it back is just as easy:
```python
loaded_model = ydf.load_model("/tmp/my_model")
loaded_model.describe()
```
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Model size : 2169 kB
Number of records: 22792
Number of columns: 15
Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)
Columns:
CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%) dtype:DTYPE_BYTES
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%) dtype:DTYPE_BYTES
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%) dtype:DTYPE_BYTES
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%) dtype:DTYPE_BYTES
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%) dtype:DTYPE_BYTES
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%) dtype:DTYPE_BYTES
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%) dtype:DTYPE_BYTES
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%) dtype:DTYPE_BYTES
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%) dtype:DTYPE_BYTES
NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661 dtype:DTYPE_INT64
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423 dtype:DTYPE_INT64
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427 dtype:DTYPE_INT64
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48 dtype:DTYPE_INT64
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01 dtype:DTYPE_INT64
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249 dtype:DTYPE_INT64
Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
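The dataspec percentages can be checked by hand. The printed counts are consistent with num-nas being a share of all 22792 records, while num-oods and most-frequent are shares of the non-missing records only (for workclass, 15879 / (22792 - 1257) gives 73.7358%, whereas 15879 / 22792 would give about 69.7%). This is an inference from the numbers above, not documented YDF behavior. A minimal check in plain Python, where `dataspec_percentages` is a hypothetical helper:

```python
# Counts copied from the dataspec above (the three columns with missing values).
NUM_RECORDS = 22792
COUNTS = {
    # column: (num-nas, num-oods, count of the most frequent value)
    "workclass": (1257, 3, 15879),
    "occupation": (1260, 4, 2870),
    "native_country": (407, 1, 20436),
}


def dataspec_percentages(nas, oods, most_frequent, num_records=NUM_RECORDS):
    """Return (nas %, oods %, most-frequent %) under the assumed denominators."""
    present = num_records - nas  # records where the value is not missing
    return (
        100 * nas / num_records,        # assumption: share of all records
        100 * oods / present,           # assumption: share of non-missing records
        100 * most_frequent / present,  # assumption: share of non-missing records
    )


for column, counts in COUNTS.items():
    nas_pct, oods_pct, top_pct = dataspec_percentages(*counts)
    print(f"{column}: nas {nas_pct:.6g}%  oods {oods_pct:.6g}%  most-frequent {top_pct:.6g}%")
```

With `.6g` formatting, each printed value should match the corresponding percentage in the dataspec.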
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.576143
Accuracy: 0.868526 CI95[W][0 1]
ErrorRate: : 0.131474
Confusion Table:
| truth\prediction | <=50K | >50K |
|---|---|---|
| <=50K | 1557 | 107 |
| >50K | 190 | 405 |
Total: 2259
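The accuracy and error rate in the report follow directly from the confusion table, whose rows are the true labels and columns the predicted labels. The same four counts also give per-class precision and recall, which the report does not print. A minimal sketch in plain Python, with the counts copied from the table above:

```python
# Confusion table copied from the evaluation above: (truth, prediction) -> count.
confusion = {
    ("<=50K", "<=50K"): 1557, ("<=50K", ">50K"): 107,
    (">50K", "<=50K"): 190, (">50K", ">50K"): 405,
}

total = sum(confusion.values())
correct = confusion[("<=50K", "<=50K")] + confusion[(">50K", ">50K")]
accuracy = correct / total
error_rate = 1 - accuracy

# Metrics for the ">50K" class, treated here as the positive class.
true_pos = confusion[(">50K", ">50K")]
precision = true_pos / (true_pos + confusion[("<=50K", ">50K")])  # among predicted >50K
recall = true_pos / (true_pos + confusion[(">50K", "<=50K")])     # among true >50K

print(f"total={total} accuracy={accuracy:.6f} error_rate={error_rate:.6f}")
# total=2259 accuracy=0.868526 error_rate=0.131474
print(f"precision(>50K)={precision:.4f} recall(>50K)={recall:.4f}")
```

The recall for ">50K" (about 0.68) is noticeably lower than the overall accuracy, which is typical of an imbalanced label where "<=50K" covers about 76% of the records.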
Variable importances measure the importance of an input feature for a model.
Variable Importance: INV_MEAN_MIN_DEPTH:
1. "age" 0.226664 ################
2. "occupation" 0.219688 #############
3. "education" 0.218052 #############
4. "capital_gain" 0.214876 ############
5. "marital_status" 0.212828 ###########
6. "relationship" 0.205975 #########
7. "fnlwgt" 0.203878 ########
8. "hours_per_week" 0.203735 ########
9. "capital_loss" 0.196549 ######
10. "native_country" 0.190548 ####
11. "workclass" 0.187810 ###
12. "education_num" 0.181204 #
13. "race" 0.180495
14. "sex" 0.177632
Variable Importance: NUM_AS_ROOT:
1. "age" 26.000000 ################
2. "capital_gain" 26.000000 ################
3. "marital_status" 20.000000 ############
4. "relationship" 17.000000 ##########
5. "education" 15.000000 ########
6. "capital_loss" 14.000000 ########
7. "hours_per_week" 14.000000 ########
8. "fnlwgt" 10.000000 #####
9. "race" 9.000000 #####
10. "education_num" 4.000000 #
11. "sex" 4.000000 #
12. "occupation" 2.000000
13. "workclass" 1.000000
14. "native_country" 1.000000
Variable Importance: NUM_NODES:
1. "occupation" 722.000000 ################
2. "fnlwgt" 515.000000 ###########
3. "age" 483.000000 ##########
4. "education" 459.000000 #########
5. "hours_per_week" 339.000000 #######
6. "capital_gain" 325.000000 ######
7. "native_country" 306.000000 ######
8. "capital_loss" 297.000000 ######
9. "relationship" 262.000000 #####
10. "workclass" 245.000000 #####
11. "marital_status" 210.000000 ####
12. "education_num" 88.000000 #
13. "sex" 41.000000
14. "race" 21.000000
Variable Importance: SUM_SCORE:
1. "relationship" 3018.761866 ################
2. "capital_gain" 2065.521668 ##########
3. "education" 1241.764059 ######
4. "marital_status" 1107.545372 #####
5. "occupation" 1094.359168 #####
6. "education_num" 699.517705 ###
7. "capital_loss" 584.055066 ###
8. "age" 582.292563 ###
9. "hours_per_week" 366.856509 #
10. "native_country" 263.872689 #
11. "fnlwgt" 216.537764 #
12. "workclass" 196.221850 #
13. "sex" 46.986269
14. "race" 5.428727
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
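A common technique for importances computed on a test dataset is permutation importance: shuffle one feature's values across the examples, re-evaluate the model, and measure how much the metric drops. To our understanding, this is what YDF reports when analyzing a model (for example with model.analyze(test_ds)), but check the API reference for the exact metrics. The sketch below shows only the generic technique, on a hypothetical toy model and toy data; `accuracy`, `permutation_importance` and `toy_model` are illustrative helpers, not YDF functions.

```python
import random


def accuracy(model, rows, labels):
    """Fraction of rows where model(row) matches the label."""
    return sum(model(row) == label for row, label in zip(rows, labels)) / len(rows)


def permutation_importance(model, rows, labels, feature, seed=0):
    """Accuracy drop after shuffling one feature's values across rows.

    Generic permutation importance (hypothetical helper, not a YDF API).
    """
    baseline = accuracy(model, rows, labels)
    values = [row[feature] for row in rows]
    random.Random(seed).shuffle(values)
    permuted = [{**row, feature: value} for row, value in zip(rows, values)]
    return baseline - accuracy(model, permuted, labels)


def toy_model(row):
    """Hypothetical one-split model reusing a threshold printed in Tree #0."""
    return row["capital_gain"] >= 7073.5


# Hypothetical toy data: the label depends on capital_gain only, never on age.
rng = random.Random(42)
rows = [{"capital_gain": rng.choice([0, 10000]), "age": rng.randint(17, 90)}
        for _ in range(1000)]
labels = [toy_model(row) for row in rows]

for feature in ("capital_gain", "age"):
    print(feature, round(permutation_importance(toy_model, rows, labels, feature), 3))
```

Shuffling "age" leaves every prediction unchanged, so its importance is exactly 0, while shuffling "capital_gain" drops accuracy to roughly chance level. Unlike the training-time importances above, this measures how much the model relies on a feature for held-out predictions.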
Only printing the first tree.
Tree #0:
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
├─(pos)─ "education" is in [BITMAP] {Bachelors, Masters, Prof-school, Doctorate} [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
|   ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
|   |   ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
|   |   |   ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
|   |   |   |   ├─(pos)─ pred:0.309737
|   |   |   |   └─(neg)─ pred:0.418684
|   |   |   └─(neg)─ pred:0.309737
|   |   └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
|   |       ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
|   |       |   ├─(pos)─ pred:0.349312
|   |       |   └─(neg)─ pred:0.417359
|   |       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
|   |           ├─(pos)─ pred:0.253437
|   |           └─(neg)─ pred:0.11557
|   └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
|       ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
|       |   ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
|       |   |   ├─(pos)─ pred:0.397934
|       |   |   └─(neg)─ pred:0.205614
|       |   └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
|       |       ├─(pos)─ pred:0.419984
|       |       └─(neg)─ pred:0.419984
|       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
|           ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
|           |   ├─(pos)─ pred:0.132992
|           |   └─(neg)─ pred:0.00826457
|           └─(neg)─ "education_num">=8.5 [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
|               ├─(pos)─ pred:0.00969668
|               └─(neg)─ pred:-0.0813718
└─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
    ├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
    |   ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
    |   |   ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
    |   |   |   ├─(pos)─ pred:0.392422
    |   |   |   └─(neg)─ pred:0.419984
    |   |   └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
    |   |       ├─(pos)─ pred:0.419984
    |   |       └─(neg)─ pred:-0.0210046
    |   └─(neg)─ pred:0.0892425
    └─(neg)─ "education_num">=12.5 [s:0.00229611 n:11121 np:2199 miss:0] ; pred:-0.10399
        ├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
        |   ├─(pos)─ "education_num">=14.5 [s:0.0110157 n:1263 np:125 miss:0] ; pred:-0.0103552
        |   |   ├─(pos)─ pred:0.16421
        |   |   └─(neg)─ pred:-0.0295298
        |   └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
        |       ├─(pos)─ pred:0.19949
        |       └─(neg)─ pred:-0.106976
        └─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
            ├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
            |   ├─(pos)─ pred:-0.0328167
            |   └─(neg)─ pred:0.292776
            └─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
                ├─(pos)─ pred:-0.0927111
                └─(neg)─ pred:-0.123347
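How to read the tree: the counts are self-consistent, so n is the number of training examples that reach a node and np is the number sent to its (pos) branch, the branch taken when the condition is true. For example, the root's 20533 examples split into 9213 and 11320. The root's n plus the 2259 validation examples of the confusion table also matches the 22792 records of the dataspec. The sketch below hand-copies the root's (neg) branch, the one taken when "relationship" is not in {<OOD>, Husband, Wife}, and routes examples to a leaf. It is an illustration only: the nested-tuple encoding and the `route` helper are hypothetical, missing values (the miss field) are not handled, and the returned pred is only this tree's contribution. In a gradient boosted model for binary classification, the contributions of all trees are summed together with the model's initial prediction and passed through a sigmoid to obtain a probability.

```python
# Hand-copied from the printout of Tree #0: the (neg) branch of the root, taken
# when "relationship" is not in {<OOD>, Husband, Wife}.
# Internal node: (feature, test, pos_child, neg_child). A float test means the
# condition feature >= test; a set test means the condition value in test.
# Leaves are plain floats: the "pred" values of the printout.
NEG_BRANCH = (
    "capital_gain", 7073.5,
    ("age", 21.5,
        ("capital_gain", 7565.5,
            ("capital_gain", 30961.5, 0.392422, 0.419984),
            ("education", {"Bachelors", "Masters", "Assoc-voc", "Prof-school"},
                0.419984, -0.0210046)),
        0.0892425),
    ("education_num", 12.5,
        ("age", 31.5,
            ("education_num", 14.5, 0.16421, -0.0295298),
            ("capital_loss", 1977, 0.19949, -0.106976)),
        ("capital_loss", 2218.5,
            ("fnlwgt", 125450, -0.0328167, 0.292776),
            ("hours_per_week", 40.5, -0.0927111, -0.123347))),
)


def route(node, example):
    """Follow (pos)/(neg) branches until a leaf and return its pred value."""
    while isinstance(node, tuple):
        feature, test, pos_child, neg_child = node
        value = example[feature]
        if isinstance(test, set):
            condition = value in test   # "feature" is in [BITMAP] {...}
        else:
            condition = value >= test   # "feature">=threshold
        node = pos_child if condition else neg_child
    return node


# Hypothetical examples (no missing values; relationship assumed not Husband/Wife).
young_graduate = {"capital_gain": 0, "education_num": 13, "age": 30, "capital_loss": 0}
investor = {"capital_gain": 7200, "age": 40, "education": "HS-grad"}
print(route(NEG_BRANCH, young_graduate))  # -0.106976
print(route(NEG_BRANCH, investor))        # -0.0210046
```

Each example only needs the features tested along its own path, which is why the two dictionaries contain different keys.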
Conclusion
That's it! You now know the basics of using YDF. 😊
To learn more, check out the other tutorials on ydf.readthedocs.io. For example, you can discover how to:
- Train models for ranking, regression, or uplifting using the task argument.
- Find the nearest neighbors between examples with model.distance().
- Enforce monotonic constraints on your features with the features argument.
- Convert the model to a TensorFlow SavedModel for serving with model.to_tensorflow_saved_model().
- Use feature selection to improve training time and model quality.
- Train on billions of examples with distributed training.