🧭 Getting started¶
Decision Forests (DFs) are machine learning algorithms for classification, regression, uplifting, and ranking. As the name suggests, DFs are built from decision trees. Today, the two most popular DF training algorithms are Random Forests and Gradient Boosted Decision Trees.
Yggdrasil Decision Forests (YDF) is a library to train, evaluate, understand, and serve decision forest models. YDF is available in Python, in C++, as a command-line interface (CLI), and in TensorFlow under the name TensorFlow Decision Forests. This notebook demonstrates the Python API, which is the recommended way to use YDF.
For the API Reference and other tutorials, check the YDF website.
Install YDF¶
pip install ydf -U
Import libraries¶
import ydf # Yggdrasil Decision Forests
import pandas as pd # We use Pandas to load small datasets
Download and load dataset¶
We use the Adult binary classification dataset. The objective is to predict the value of the income column, which can be either <=50K or >50K, using the other numerical and categorical columns. This dataset contains missing values.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Download and load the dataset as Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Print the first 5 training examples
train_ds.head(5)
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
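Since the dataset contains missing values, it can help to count them per column before training. The sketch below uses a hypothetical three-row excerpt (not the real file) and assumes missing values appear as empty CSV fields, which Pandas reads as NaN. On the real data, the same call would be train_ds.isna().sum().

```python
import io

import pandas as pd

# Hypothetical three-row excerpt with the same kind of columns as the Adult
# dataset. The empty fields stand in for missing values (an assumption about
# how missing values are encoded).
csv_text = """age,workclass,occupation,income
44,Private,Machine-op-inspct,<=50K
20,,Other-service,<=50K
67,Self-emp-inc,,>50K
"""
sample = pd.read_csv(io.StringIO(csv_text))

# Number of missing values in each column.
missing_per_column = sample.isna().sum()
print(missing_per_column)
```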
Train a model¶
Let's train a gradient boosted trees model using default values for all the hyper-parameters.
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)
Train model on 22792 examples
Model trained in 0:00:03.698584
Remarks
- YDF distinguishes between learning algorithms (a.k.a. learners, such as GradientBoostedTreesLearner) and models. Later, in more advanced examples, you will see why we do it :).
- The only required parameter for a learner is label. Other parameters have good default values.
- We did not specify input features, so all the other columns are used as input features. The type of each feature is automatically detected (e.g. numerical, categorical, boolean, text, with possibly missing values) and ingested.
- By default, learners train classification models. Other tasks (e.g., regression, ranking, uplifting) can be configured with the task parameter, e.g. task=ydf.Task.REGRESSION.
- Training logs can be shown during training with the verbose=2 argument, or after training with model.describe(). This is useful for debugging and understanding the training process.
- A validation dataset was not specified. In this case, learners such as GradientBoostedTreesLearner will extract data from the training dataset that can be used for validation. Other learners such as RandomForestLearner do not require a validation dataset and will use all the data for training.
Looking at model¶
With model.describe(), we can look at:
- Model: The model task, input features and size.
- Dataspec: The type and statistics of all the input features.
- Training: The training and validation loss and metrics.
- Tuning (only if hyper-parameter tuning is enabled): The tuning logs.
- Variable importance: Which features matter most to the model.
- Structure: The trees in the model.
model.describe()
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Model size : 2174 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%)
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%)
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%)
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.576162
Accuracy: 0.868526 CI95[W][0 1]
ErrorRate: : 0.131474
Confusion Table:
truth\prediction
        <=50K  >50K
<=50K    1557   107
 >50K     190   405
Total: 2259
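The numbers in this report can be cross-checked by hand. The sketch below recomputes the accuracy from the confusion table and relates the 2259 evaluated examples to the 22792 training records; the difference, 20533, is the example count (n) reported at the root of the tree printed in the Structure part of this report. Reading the 2259 examples as the validation data extracted from the training dataset is an interpretation of these numbers, not something the report states explicitly. The precision and recall for the >50K class are extra metrics derived from the same table.

```python
# Numbers copied from the model.describe() output above.
num_training_records = 22792  # "Number of records" in the dataspec

# Validation confusion table, indexed as confusion[truth][prediction].
confusion = {
    "<=50K": {"<=50K": 1557, ">50K": 107},
    ">50K": {"<=50K": 190, ">50K": 405},
}

num_validation = sum(sum(row.values()) for row in confusion.values())
num_correct = sum(confusion[label][label] for label in confusion)

accuracy = num_correct / num_validation
held_out_fraction = num_validation / num_training_records

# Precision and recall for the ">50K" class.
positive = ">50K"
true_positives = confusion[positive][positive]
predicted_positives = sum(row[positive] for row in confusion.values())
actual_positives = sum(confusion[positive].values())
precision = true_positives / predicted_positives
recall = true_positives / actual_positives

print(f"validation examples: {num_validation}")
print(f"accuracy: {accuracy:.6f}")
print(f"held-out fraction: {held_out_fraction:.4f}")
print(f"precision: {precision:.4f}, recall: {recall:.4f}")
```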
Variable importances measure the importance of an input feature for a model.
1. "age" 0.226642 ################ 2. "occupation" 0.219727 ############# 3. "capital_gain" 0.214876 ############ 4. "education" 0.213746 ########### 5. "marital_status" 0.212739 ########### 6. "relationship" 0.206040 ######### 7. "fnlwgt" 0.203843 ######## 8. "hours_per_week" 0.203735 ######## 9. "capital_loss" 0.196549 ###### 10. "native_country" 0.190548 #### 11. "workclass" 0.187795 ### 12. "education_num" 0.184215 ## 13. "race" 0.180495 14. "sex" 0.177647
1. "age" 26.000000 ################ 2. "capital_gain" 26.000000 ################ 3. "marital_status" 20.000000 ############ 4. "relationship" 17.000000 ########## 5. "capital_loss" 14.000000 ######## 6. "hours_per_week" 14.000000 ######## 7. "education" 12.000000 ####### 8. "fnlwgt" 10.000000 ##### 9. "race" 9.000000 ##### 10. "education_num" 7.000000 ### 11. "sex" 4.000000 # 12. "occupation" 2.000000 13. "workclass" 1.000000 14. "native_country" 1.000000
1. "occupation" 724.000000 ################ 2. "fnlwgt" 513.000000 ########### 3. "age" 483.000000 ########## 4. "education" 464.000000 ########## 5. "hours_per_week" 339.000000 ####### 6. "capital_gain" 326.000000 ###### 7. "native_country" 306.000000 ###### 8. "capital_loss" 297.000000 ###### 9. "relationship" 262.000000 ##### 10. "workclass" 244.000000 ##### 11. "marital_status" 210.000000 #### 12. "education_num" 82.000000 # 13. "sex" 42.000000 14. "race" 21.000000
1. "relationship" 3014.690076 ################ 2. "capital_gain" 2065.521668 ########## 3. "education" 1144.490954 ###### 4. "marital_status" 1111.389695 ##### 5. "occupation" 1094.619502 ##### 6. "education_num" 796.666823 #### 7. "capital_loss" 584.055066 ### 8. "age" 582.288569 ### 9. "hours_per_week" 366.856509 # 10. "native_country" 263.872689 # 11. "fnlwgt" 216.537764 # 12. "workclass" 196.085503 # 13. "sex" 47.217730 14. "race" 5.428727
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Only printing the first tree.
Tree #0:
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
    ├─(pos)─ "education_num">=12.5 [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
    |        ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
    |        |        ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
    |        |        |        ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
    |        |        |        |        ├─(pos)─ pred:0.309737
    |        |        |        |        └─(neg)─ pred:0.418684
    |        |        |        └─(neg)─ pred:0.309737
    |        |        └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
    |        |                 ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
    |        |                 |        ├─(pos)─ pred:0.349312
    |        |                 |        └─(neg)─ pred:0.417359
    |        |                 └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
    |        |                          ├─(pos)─ pred:0.253437
    |        |                          └─(neg)─ pred:0.11557
    |        └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
    |                 ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
    |                 |        ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
    |                 |        |        ├─(pos)─ pred:0.397934
    |                 |        |        └─(neg)─ pred:0.205614
    |                 |        └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
    |                 |                 ├─(pos)─ pred:0.419984
    |                 |                 └─(neg)─ pred:0.419984
    |                 └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
    |                          ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
    |                          |        ├─(pos)─ pred:0.132992
    |                          |        └─(neg)─ pred:0.00826457
    |                          └─(neg)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Bachelors, Masters, Assoc-voc, Assoc-acdm, Prof-school, Doctorate} [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
    |                                   ├─(pos)─ pred:0.00969668
    |                                   └─(neg)─ pred:-0.0813718
    └─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
             ├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
             |        ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
             |        |        ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
             |        |        |        ├─(pos)─ pred:0.392422
             |        |        |        └─(neg)─ pred:0.419984
             |        |        └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
             |        |                 ├─(pos)─ pred:0.419984
             |        |                 └─(neg)─ pred:-0.0210046
             |        └─(neg)─ pred:0.0892425
             └─(neg)─ "education" is in [BITMAP] {<OOD>, Bachelors, Masters, Prof-school, Doctorate} [s:0.00229611 n:11121 np:2199 miss:1] ; pred:-0.10399
                      ├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
                      |        ├─(pos)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Assoc-voc, 11th, Assoc-acdm, 10th, 7th-8th, Prof-school, 9th, ...[5 left]} [s:0.0110157 n:1263 np:125 miss:1] ; pred:-0.0103552
                      |        |        ├─(pos)─ pred:0.16421
                      |        |        └─(neg)─ pred:-0.0295298
                      |        └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
                      |                 ├─(pos)─ pred:0.19949
                      |                 └─(neg)─ pred:-0.106976
                      └─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
                               ├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
                               |        ├─(pos)─ pred:-0.0328167
                               |        └─(neg)─ pred:0.292776
                               └─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
                                        ├─(pos)─ pred:-0.0927111
                                        └─(neg)─ pred:-0.123347
Make predictions¶
model.predict(ds) applies a model and returns the predictions as a Numpy array.
model.predict(test_ds)
array([0.01860435, 0.36130956, 0.83858865, ..., 0.03087652, 0.08280362, 0.00970956], dtype=float32)
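The predictions above are floating-point scores rather than class names. A minimal sketch of turning them into labels follows, assuming each value is the predicted probability of the >50K class and using an illustrative 0.5 cut-off. Both the interpretation and the threshold are assumptions made for this example, not something stated in this tutorial.

```python
# First three values copied from the model.predict(test_ds) output above.
probabilities = [0.01860435, 0.36130956, 0.83858865]

# Assumption: each value is the probability of the positive class ">50K".
NEGATIVE, POSITIVE = "<=50K", ">50K"
THRESHOLD = 0.5  # Illustrative cut-off chosen for this sketch.

predicted_classes = [
    POSITIVE if probability >= THRESHOLD else NEGATIVE
    for probability in probabilities
]
print(predicted_classes)
```

A different threshold trades false positives against false negatives, so the cut-off is usually picked from the evaluation results rather than fixed at 0.5.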
Methods that consume datasets, such as train and predict, support multiple dataset formats such as Pandas DataFrames, dictionaries of lists or Numpy arrays, TensorFlow Datasets, and even file paths!
# Prediction with a dictionary
model.predict({
'age': [39],
'workclass': ['State-gov'],
'fnlwgt': [77516],
'education': ['Bachelors'],
'education_num': [13],
'marital_status': ['Never-married'],
'occupation': ['Adm-clerical'],
'relationship': ['Not-in-family'],
'race': ['White'],
'sex': ['Male'],
'capital_gain': [2174],
'capital_loss': [0],
'hours_per_week': [40],
'native_country': ['United-States'],
'income': ['<=50K'],
})
array([0.01860435], dtype=float32)
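The dictionary format is column-oriented: one key per feature, each mapping to a list with one value per example. Application data often arrives as one dictionary per example instead. The helper below, records_to_columns, is a hypothetical name introduced for this sketch (it is not part of YDF) and shows one way to convert between the two layouts using only the standard library. Only three features are used to keep the example short.

```python
def records_to_columns(records):
    """Turns a list of per-example dicts into a dict of per-feature lists."""
    if not records:
        return {}
    feature_names = list(records[0])
    for record in records:
        if set(record) != set(feature_names):
            raise ValueError("All records must have the same keys")
    return {name: [record[name] for record in records] for name in feature_names}


# Two hypothetical examples with a subset of the Adult features.
records = [
    {"age": 39, "workclass": "State-gov", "hours_per_week": 40},
    {"age": 52, "workclass": "Private", "hours_per_week": 45},
]
columns = records_to_columns(records)
print(columns)
```

The result has the same shape as the dictionary passed to model.predict above; a real call would need the values for the model's input features.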
Evaluate model¶
While the validation dataset above provides an indication of the model's quality, we also want to evaluate the model on the test dataset.
evaluation = model.evaluate(test_ds)
# Query individual evaluation metrics
print(f"test accuracy: {evaluation.accuracy}")
# Show the full evaluation report
print("Full evaluation report:")
evaluation
test accuracy: 0.8738867847271983
Full evaluation report:
Label \ Pred | <=50K | >50K |
---|---|---|
<=50K | 6962 | 782 |
>50K | 450 | 1575 |
Analyze model¶
With model.analyze(ds), we can understand how the model behaves. For example, Partial Dependence Plots (PDPs) show how the model's predictions respond to changes in feature values.
model.analyze(test_ds, sampling=0.1)
Variable importances measure the importance of an input feature for a model.
1. "capital_gain" 0.052513 ################ 2. "occupation" 0.020882 ###### 3. "age" 0.015559 #### 4. "relationship" 0.015150 #### 5. "marital_status" 0.014331 #### 6. "capital_loss" 0.014331 #### 7. "education" 0.009110 ## 8. "hours_per_week" 0.006551 # 9. "education_num" 0.005323 # 10. "workclass" 0.003378 11. "race" 0.001024 12. "sex" 0.000921 13. "fnlwgt" 0.000614 14. "native_country" 0.000614
1. "capital_gain" 0.248326 ################ 2. "age" 0.051386 ### 3. "marital_status" 0.046224 ## 4. "capital_loss" 0.044403 ## 5. "occupation" 0.037985 ## 6. "relationship" 0.037500 ## 7. "education" 0.021677 # 8. "hours_per_week" 0.015487 9. "education_num" 0.008588 10. "workclass" 0.003808 11. "sex" 0.003478 12. "fnlwgt" 0.002788 13. "native_country" 0.001978 14. "race" 0.001111
1. "capital_gain" 0.061589 ################ 2. "age" 0.033311 ######## 3. "marital_status" 0.029546 ####### 4. "relationship" 0.020694 ##### 5. "occupation" 0.019686 ##### 6. "capital_loss" 0.014316 ### 7. "education" 0.012061 ## 8. "hours_per_week" 0.009984 ## 9. "education_num" 0.004140 10. "sex" 0.001985 11. "workclass" 0.001577 12. "native_country" 0.001397 13. "fnlwgt" 0.000936 14. "race" 0.000637
1. "capital_gain" 0.248064 ################ 2. "age" 0.051338 ### 3. "marital_status" 0.045982 ## 4. "capital_loss" 0.044387 ## 5. "occupation" 0.037982 ## 6. "relationship" 0.037494 ## 7. "education" 0.021676 # 8. "hours_per_week" 0.015486 9. "education_num" 0.008585 10. "workclass" 0.003812 11. "sex" 0.003477 12. "fnlwgt" 0.002791 13. "native_country" 0.001981 14. "race" 0.001112
1. "age" 0.226642 ################ 2. "occupation" 0.219727 ############# 3. "capital_gain" 0.214876 ############ 4. "education" 0.213746 ########### 5. "marital_status" 0.212739 ########### 6. "relationship" 0.206040 ######### 7. "fnlwgt" 0.203843 ######## 8. "hours_per_week" 0.203735 ######## 9. "capital_loss" 0.196549 ###### 10. "native_country" 0.190548 #### 11. "workclass" 0.187795 ### 12. "education_num" 0.184215 ## 13. "race" 0.180495 14. "sex" 0.177647
1. "age" 26.000000 ################ 2. "capital_gain" 26.000000 ################ 3. "marital_status" 20.000000 ############ 4. "relationship" 17.000000 ########## 5. "capital_loss" 14.000000 ######## 6. "hours_per_week" 14.000000 ######## 7. "education" 12.000000 ####### 8. "fnlwgt" 10.000000 ##### 9. "race" 9.000000 ##### 10. "education_num" 7.000000 ### 11. "sex" 4.000000 # 12. "occupation" 2.000000 13. "workclass" 1.000000 14. "native_country" 1.000000
1. "occupation" 724.000000 ################ 2. "fnlwgt" 513.000000 ########### 3. "age" 483.000000 ########## 4. "education" 464.000000 ########## 5. "hours_per_week" 339.000000 ####### 6. "capital_gain" 326.000000 ###### 7. "native_country" 306.000000 ###### 8. "capital_loss" 297.000000 ###### 9. "relationship" 262.000000 ##### 10. "workclass" 244.000000 ##### 11. "marital_status" 210.000000 #### 12. "education_num" 82.000000 # 13. "sex" 42.000000 14. "race" 21.000000
1. "relationship" 3014.690076 ################ 2. "capital_gain" 2065.521668 ########## 3. "education" 1144.490954 ###### 4. "marital_status" 1111.389695 ##### 5. "occupation" 1094.619502 ##### 6. "education_num" 796.666823 #### 7. "capital_loss" 584.055066 ### 8. "age" 582.288569 ### 9. "hours_per_week" 366.856509 # 10. "native_country" 263.872689 # 11. "fnlwgt" 216.537764 # 12. "workclass" 196.085503 # 13. "sex" 47.217730 14. "race" 5.428727
Benchmark model speed¶
In applications where model speed is critical, we can use model.benchmark(ds) to evaluate the speed of the model.
model.benchmark(test_ds)
Inference time per example and per cpu core: 0.891 us (microseconds)
Estimated over 345 runs over 3.004 seconds.
* Measured with the C++ serving API. Check model.to_cpp() for details.
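A per-example latency is easier to reason about once converted into a throughput. The short calculation below does this for the figure reported above. The count of 9769 test examples is the sum of the four cells of the test confusion table in the evaluation section, and the result is a single-core estimate that ignores data loading and Python overhead.

```python
# Latency copied from the model.benchmark(test_ds) output above.
time_per_example_us = 0.891

# Sum of the test confusion table cells: 6962 + 782 + 450 + 1575.
num_test_examples = 6962 + 782 + 450 + 1575

# Examples processed per second on one CPU core.
examples_per_second = 1_000_000 / time_per_example_us

# Estimated time to run the model over the whole test set, in milliseconds.
total_time_ms = time_per_example_us * num_test_examples / 1000

print(f"{examples_per_second:,.0f} examples per second per core")
print(f"{total_time_ms:.3f} ms for {num_test_examples} examples")
```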
The benchmark measures the speed of the model when using the C++ API. The Python API will be slower due to the overhead of the Python interpreter. If you are not familiar with the C++ API, you can use the model.to_cpp() method to generate C++ code that you can run to evaluate the model's speed.
print(model.to_cpp())
// Automatically generated code running an Yggdrasil Decision Forests model in
// C++. This code was generated with "model.to_cpp()".
//
// Date of generation: 2023-12-19 15:29:09.343331
// YDF Version: 0.0.8
//
// How to use this code:
//
// 1. Copy this code in a new .h file.
// 2. If you use Bazel/Blaze, use the following dependencies:
//      //third_party/absl/status:statusor
//      //third_party/absl/strings
//      //external/ydf_cc/yggdrasil_decision_forests/api:serving
// 3. In your existing code, include the .h file and do:
//      // Load the model (to do only once).
//      namespace ydf = yggdrasil_decision_forests;
//      const auto model = ydf::exported_model_123::Load(<path to model>);
//      // Run the model
//      predictions = model.Predict();
// 4. By default, the "Predict" function takes no inputs and creates fake
//    examples. In practice, you want to add your input data as arguments to
//    "Predict" and call "examples->Set..." functions accordingly.
// 4. (Bonus)
//    Allocate one "examples" and "predictions" per thread and reuse them to
//    speed-up the inference.
//
#ifndef YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
#define YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model

#include <memory>
#include <vector>

#include "third_party/absl/status/statusor.h"
#include "third_party/absl/strings/string_view.h"
#include "external/ydf_cc/yggdrasil_decision_forests/api/serving.h"

namespace yggdrasil_decision_forests {
namespace exported_model_my_model {

struct ServingModel {
  std::vector<float> Predict() const;

  // Compiled model.
  std::unique_ptr<serving_api::FastEngine> engine;

  // Index of the input features of the model.
  //
  // Non-owning pointer. The data is owned by the engine.
  const serving_api::FeaturesDefinition* features;

  // Number of output predictions for each example.
  // Equal to 1 for regression, ranking and binary classification with compact
  // format. Equal to the number of classes for classification.
  int NumPredictionDimension() const {
    return engine->NumPredictionDimension();
  }

  // Indexes of the input features.
  serving_api::NumericalFeatureId feature_age;
  serving_api::CategoricalFeatureId feature_workclass;
  serving_api::NumericalFeatureId feature_fnlwgt;
  serving_api::CategoricalFeatureId feature_education;
  serving_api::NumericalFeatureId feature_education_num;
  serving_api::CategoricalFeatureId feature_marital_status;
  serving_api::CategoricalFeatureId feature_occupation;
  serving_api::CategoricalFeatureId feature_relationship;
  serving_api::CategoricalFeatureId feature_race;
  serving_api::CategoricalFeatureId feature_sex;
  serving_api::NumericalFeatureId feature_capital_gain;
  serving_api::NumericalFeatureId feature_capital_loss;
  serving_api::NumericalFeatureId feature_hours_per_week;
  serving_api::CategoricalFeatureId feature_native_country;
};

// TODO: Pass input feature values to "Predict".
inline std::vector<float> ServingModel::Predict() const {
  // Allocate memory for 2 examples. Alternatively, for speed-sensitive code,
  // an "examples" object can be allocated for each thread and reused. It is
  // okay to allocate more examples than needed.
  const int num_examples = 2;
  auto examples = engine->AllocateExamples(num_examples);

  // Set all the values to be missing. The values may then be overridden by the
  // "Set*" methods. If all the values are set with "Set*" methods,
  // "FillMissing" can be skipped.
  examples->FillMissing(*features);

  // Example #0
  examples->SetNumerical(/*example_idx=*/0, feature_age, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_workclass, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_fnlwgt, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_education, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_education_num, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_marital_status, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_occupation, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_relationship, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_race, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_sex, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_capital_gain, 1.f, *features);
  examples->SetNumerical(/*example_idx=*/0, feature_capital_loss, 1.f, *features);
  examples->SetNumerical(/*example_idx=*/0, feature_hours_per_week, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_native_country, "A", *features);

  // Example #1
  examples->SetNumerical(/*example_idx=*/1, feature_age, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_workclass, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_fnlwgt, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_education, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_education_num, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_marital_status, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_occupation, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_relationship, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_race, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_sex, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_capital_gain, 2.f, *features);
  examples->SetNumerical(/*example_idx=*/1, feature_capital_loss, 2.f, *features);
  examples->SetNumerical(/*example_idx=*/1, feature_hours_per_week, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_native_country, "B", *features);

  // Run the model on the two examples.
  //
  // For speed-sensitive code, reuse the same predictions.
  std::vector<float> predictions;
  engine->Predict(*examples, num_examples, &predictions);
  return predictions;
}

inline absl::StatusOr<ServingModel> Load(absl::string_view path) {
  ServingModel m;

  // Load the model
  ASSIGN_OR_RETURN(auto model, serving_api::LoadModel(path));

  // Compile the model into an inference engine.
  ASSIGN_OR_RETURN(m.engine, model->BuildFastEngine());

  // Index the input features of the model.
  m.features = &m.engine->features();

  // Index the input features.
  ASSIGN_OR_RETURN(m.feature_age, m.features->GetNumericalFeatureId("age"));
  ASSIGN_OR_RETURN(m.feature_workclass, m.features->GetCategoricalFeatureId("workclass"));
  ASSIGN_OR_RETURN(m.feature_fnlwgt, m.features->GetNumericalFeatureId("fnlwgt"));
  ASSIGN_OR_RETURN(m.feature_education, m.features->GetCategoricalFeatureId("education"));
  ASSIGN_OR_RETURN(m.feature_education_num, m.features->GetNumericalFeatureId("education_num"));
  ASSIGN_OR_RETURN(m.feature_marital_status, m.features->GetCategoricalFeatureId("marital_status"));
  ASSIGN_OR_RETURN(m.feature_occupation, m.features->GetCategoricalFeatureId("occupation"));
  ASSIGN_OR_RETURN(m.feature_relationship, m.features->GetCategoricalFeatureId("relationship"));
  ASSIGN_OR_RETURN(m.feature_race, m.features->GetCategoricalFeatureId("race"));
  ASSIGN_OR_RETURN(m.feature_sex, m.features->GetCategoricalFeatureId("sex"));
  ASSIGN_OR_RETURN(m.feature_capital_gain, m.features->GetNumericalFeatureId("capital_gain"));
  ASSIGN_OR_RETURN(m.feature_capital_loss, m.features->GetNumericalFeatureId("capital_loss"));
  ASSIGN_OR_RETURN(m.feature_hours_per_week, m.features->GetNumericalFeatureId("hours_per_week"));
  ASSIGN_OR_RETURN(m.feature_native_country, m.features->GetCategoricalFeatureId("native_country"));

  return m;
}

}  // namespace exported_model_my_model
}  // namespace yggdrasil_decision_forests

#endif  // YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
Save model¶
Finally, we save the model for later use.
model.save("/tmp/my_model")
The saved model can then be loaded with:
loaded_model = ydf.load_model("/tmp/my_model")
print(f"This is a {loaded_model.name()} model.")
This is a GRADIENT_BOOSTED_TREES model.
Conclusion¶
That's it. You now know the basic capabilities of YDF 😊.
To learn more about YDF, check the other tutorials on ydf.readthedocs.io. For instance, learn how to:
- Train ranking, regression or uplifting models with the task argument.
- Measure distance and find the nearest neighbors between examples with model.distance.
- Enforce monotonic constraints on your features with the features argument.
- Run the models in a webpage in JavaScript with model.to_javascript().
- Convert the model into a TensorFlow SavedModel and run it in TensorFlow Serving with model.to_tensorflow_saved_model().
- Train a model on billions of training examples using distributed training.