🧭 Getting started¶
Decision Forests (DFs) are machine learning algorithms for classification, regression, uplifting, and ranking. As the name suggests, DFs are built from decision trees. Today, the two most popular DF training algorithms are Random Forests and Gradient Boosted Decision Trees.
Yggdrasil Decision Forests (YDF) is a library to train, evaluate, understand, and serve decision forest models. YDF is available through several APIs: Python, C++, a CLI, and TensorFlow (under the name TensorFlow Decision Forests). This notebook demonstrates the Python API, which is the recommended way to use YDF.
For the API Reference and other tutorials, check the YDF website.
Install YDF¶
pip install ydf -U
Import libraries¶
import ydf # Yggdrasil Decision Forests
import pandas as pd # We use Pandas to load small datasets
Download and load dataset¶
We use the binary classification Adult dataset. The objective is to predict the value of the income column, which can be either <=50K or >50K, using the other numerical and categorical columns. This dataset contains missing values.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Download and load the dataset as Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Print the first 5 training examples
train_ds.head(5)
|   | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
Train a model¶
Let's train a gradient boosted trees model using default values for all the hyper-parameters.
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)
Train model on 22792 examples
Model trained in 0:00:03.698584
Remarks

- YDF distinguishes between learning algorithms (a.k.a. learners, such as GradientBoostedTreesLearner) and models. Later, in more advanced examples, you will see why we do it :).
- The only required parameter for a learner is label. Other parameters have good default values.
- We did not specify input features, so all the columns are used as input features. The type of each feature (e.g. numerical, categorical, boolean, text, possibly with missing values) is detected automatically and ingested accordingly.
- By default, learners train classification models. Other tasks (e.g. regression, ranking, uplifting) can be configured with the task parameter, e.g. task=ydf.Task.REGRESSION.
- Training logs can be shown during training with the verbose=2 argument, or after training with model.describe(). This is useful for debugging and understanding the training process.
- A validation dataset was not specified. In this case, learners such as GradientBoostedTreesLearner extract part of the training dataset for validation, while learners such as RandomForestLearner do not require a validation dataset and use all the data for training.

A sketch combining several of these options follows.
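The snippet below is a minimal sketch, not part of the original recipe: num_trees is shown only as an example hyper-parameter (300 is also its default value), and the verbose=2 argument of train() follows the remark above; verify both against the API Reference.

# Sketch: configure the learner explicitly instead of relying on defaults.
learner = ydf.GradientBoostedTreesLearner(
    label="income",                # the only required argument
    task=ydf.Task.CLASSIFICATION,  # the default task; e.g. ydf.Task.REGRESSION for regression
    num_trees=300,                 # example hyper-parameter (equal to the default)
)
model = learner.train(train_ds, verbose=2)  # verbose=2 prints the training logs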
Looking at the model¶
With model.describe(), we can look at:
- Model: The model task, input features and size.
- Dataspec: The types and statistics of all the input features.
- Training: The training and validation loss and metrics.
- Tuning (only if hyper-parameter tuning is enabled): The tuning logs.
- Variable importance: What features matter most to the model.
- Structure: The trees in the model.
model.describe()
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Model size : 2174 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%)
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%)
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%)
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.576162
Accuracy: 0.868526 CI95[W][0 1]
ErrorRate: 0.131474

Confusion Table:
truth\prediction  <=50K  >50K
<=50K              1557   107
>50K                190   405
Total: 2259
Variable importances measure the importance of an input feature for a model.
1. "age" 0.226642 ################ 2. "occupation" 0.219727 ############# 3. "capital_gain" 0.214876 ############ 4. "education" 0.213746 ########### 5. "marital_status" 0.212739 ########### 6. "relationship" 0.206040 ######### 7. "fnlwgt" 0.203843 ######## 8. "hours_per_week" 0.203735 ######## 9. "capital_loss" 0.196549 ###### 10. "native_country" 0.190548 #### 11. "workclass" 0.187795 ### 12. "education_num" 0.184215 ## 13. "race" 0.180495 14. "sex" 0.177647
1. "age" 26.000000 ################ 2. "capital_gain" 26.000000 ################ 3. "marital_status" 20.000000 ############ 4. "relationship" 17.000000 ########## 5. "capital_loss" 14.000000 ######## 6. "hours_per_week" 14.000000 ######## 7. "education" 12.000000 ####### 8. "fnlwgt" 10.000000 ##### 9. "race" 9.000000 ##### 10. "education_num" 7.000000 ### 11. "sex" 4.000000 # 12. "occupation" 2.000000 13. "workclass" 1.000000 14. "native_country" 1.000000
1. "occupation" 724.000000 ################ 2. "fnlwgt" 513.000000 ########### 3. "age" 483.000000 ########## 4. "education" 464.000000 ########## 5. "hours_per_week" 339.000000 ####### 6. "capital_gain" 326.000000 ###### 7. "native_country" 306.000000 ###### 8. "capital_loss" 297.000000 ###### 9. "relationship" 262.000000 ##### 10. "workclass" 244.000000 ##### 11. "marital_status" 210.000000 #### 12. "education_num" 82.000000 # 13. "sex" 42.000000 14. "race" 21.000000
1. "relationship" 3014.690076 ################ 2. "capital_gain" 2065.521668 ########## 3. "education" 1144.490954 ###### 4. "marital_status" 1111.389695 ##### 5. "occupation" 1094.619502 ##### 6. "education_num" 796.666823 #### 7. "capital_loss" 584.055066 ### 8. "age" 582.288569 ### 9. "hours_per_week" 366.856509 # 10. "native_country" 263.872689 # 11. "fnlwgt" 216.537764 # 12. "workclass" 196.085503 # 13. "sex" 47.217730 14. "race" 5.428727
These variable importances are computed during training. Additional, and often more informative, variable importances are available when analyzing the model on a test dataset (a sketch follows the tree structure below).
Only printing the first tree.
Tree #0:
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
├─(pos)─ "education_num">=12.5 [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
|   ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
|   |   ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
|   |   |   ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
|   |   |   |   ├─(pos)─ pred:0.309737
|   |   |   |   └─(neg)─ pred:0.418684
|   |   |   └─(neg)─ pred:0.309737
|   |   └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
|   |       ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
|   |       |   ├─(pos)─ pred:0.349312
|   |       |   └─(neg)─ pred:0.417359
|   |       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
|   |           ├─(pos)─ pred:0.253437
|   |           └─(neg)─ pred:0.11557
|   └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
|       ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
|       |   ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
|       |   |   ├─(pos)─ pred:0.397934
|       |   |   └─(neg)─ pred:0.205614
|       |   └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
|       |       ├─(pos)─ pred:0.419984
|       |       └─(neg)─ pred:0.419984
|       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
|           ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
|           |   ├─(pos)─ pred:0.132992
|           |   └─(neg)─ pred:0.00826457
|           └─(neg)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Bachelors, Masters, Assoc-voc, Assoc-acdm, Prof-school, Doctorate} [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
|               ├─(pos)─ pred:0.00969668
|               └─(neg)─ pred:-0.0813718
└─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
    ├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
    |   ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
    |   |   ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
    |   |   |   ├─(pos)─ pred:0.392422
    |   |   |   └─(neg)─ pred:0.419984
    |   |   └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
    |   |       ├─(pos)─ pred:0.419984
    |   |       └─(neg)─ pred:-0.0210046
    |   └─(neg)─ pred:0.0892425
    └─(neg)─ "education" is in [BITMAP] {<OOD>, Bachelors, Masters, Prof-school, Doctorate} [s:0.00229611 n:11121 np:2199 miss:1] ; pred:-0.10399
        ├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
        |   ├─(pos)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Assoc-voc, 11th, Assoc-acdm, 10th, 7th-8th, Prof-school, 9th, ...[5 left]} [s:0.0110157 n:1263 np:125 miss:1] ; pred:-0.0103552
        |   |   ├─(pos)─ pred:0.16421
        |   |   └─(neg)─ pred:-0.0295298
        |   └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
        |       ├─(pos)─ pred:0.19949
        |       └─(neg)─ pred:-0.106976
        └─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
            ├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
            |   ├─(pos)─ pred:-0.0328167
            |   └─(neg)─ pred:0.292776
            └─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
                ├─(pos)─ pred:-0.0927111
                └─(neg)─ pred:-0.123347
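As noted above, richer variable importances become available when the model is analyzed on a held-out dataset. A minimal sketch using model.analyze (default arguments assumed to be sufficient here):

# Analyze the model on the test dataset. In a notebook, the returned
# analysis displays additional variable importances (e.g. permutation
# importances) along with other diagnostics.
analysis = model.analyze(test_ds)
analysis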
Make predictions¶
model.predict(ds) applies a model and returns the predictions as a Numpy array.
model.predict(test_ds)
array([0.01860435, 0.36130956, 0.83858865, ..., 0.03087652, 0.08280362, 0.00970956], dtype=float32)
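Each prediction is the probability of the positive label (>50K). As a small sketch (the 0.5 cutoff is an arbitrary choice for illustration, not a YDF setting), the probabilities can be converted into class labels with Numpy:

import numpy as np

# Convert probabilities of ">50K" into class labels using a 0.5 cutoff.
probabilities = model.predict(test_ds)
predicted_labels = np.where(probabilities >= 0.5, ">50K", "<=50K")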
Methods that consume datasets, such as train and predict, support multiple dataset formats: Pandas DataFrames, dictionaries of lists or Numpy arrays, TensorFlow Datasets, and even file paths (sketched after the dictionary example below)!
# Prediction with a dictionary
model.predict({
'age': [39],
'workclass': ['State-gov'],
'fnlwgt': [77516],
'education': ['Bachelors'],
'education_num': [13],
'marital_status': ['Never-married'],
'occupation': ['Adm-clerical'],
'relationship': ['Not-in-family'],
'race': ['White'],
'sex': ['Male'],
'capital_gain': [2174],
'capital_loss': [0],
'hours_per_week': [40],
'native_country': ['United-States'],
'income': ['<=50K'],
})
array([0.01860435], dtype=float32)
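For file paths, YDF expects the dataset format as a prefix. The path below is hypothetical and only illustrates the convention; check the API Reference for the list of supported formats.

# Sketch: predict directly from a CSV file (hypothetical local path).
predictions = model.predict("csv:/tmp/adult_test.csv")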
Evaluate model¶
While the validation dataset above provides an indication of the model's quality, we also want to evaluate the model on the test dataset.
evaluation = model.evaluate(test_ds)
# Query individual evaluation metrics
print(f"test accuracy: {evaluation.accuracy}")
# Show the full evaluation report
print("Full evaluation report:")
evaluation
test accuracy: 0.8738867847271983
Full evaluation report:
| Label \ Pred | <=50K | >50K |
|---|---|---|
| <=50K | 6962 | 782 |
| >50K | 450 | 1575 |
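Other metrics are exposed as attributes of the evaluation object, in the same way accuracy was queried above. A small sketch, assuming a loss attribute is available for this model type (verify against the API Reference):

# Query another metric of the evaluation (sketch; available metrics
# depend on the model task).
print(f"test loss: {evaluation.loss}")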