pip install ydf -U
What is model tuning?¶
Model tuning, also known as automated model hyperparameter optimization or AutoML, involves finding the optimal hyperparameters for a learner to maximize the performance of a model. YDF supports model tuning out-of-the-box.
YDF model tuning has two modes. A user can either manually specify the hyperparameters to optimize and their candidate values, or rely on pre-configured tuning. Pre-configured tuning is simpler, while manual configuration gives you more control. We demonstrate both options in this tutorial.
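As a preview of the pre-configured mode, the following sketch lets YDF define the search space itself via the `automatic_search_space` argument of `ydf.RandomSearchTuner` (the label name is illustrative):

import ydf

# Pre-configured tuning: YDF selects which hyper-parameters to
# optimize and their candidate values automatically.
tuner = ydf.RandomSearchTuner(num_trials=50, automatic_search_space=True)
learner = ydf.GradientBoostedTreesLearner(label="income", tuner=tuner)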
Tuning can be done on a single machine or across multiple machines using distributed training. This tutorial focuses on tuning on a single machine. Local tuning is simple to set up and can produce excellent results on small datasets.
Distributed model tuning¶
Distributed tuning can be advantageous for models that take a long time to train or for large hyper-parameter search spaces. It requires configuring workers and setting the `workers` constructor argument of the learner. Once the workers are set up, the model tuning strategy is the same as for tuning on a local machine. For more information, see the distributed training tutorial.
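For illustration only, a distributed tuning setup could look like the following sketch. It assumes two workers have already been started on remote machines (e.g., with `ydf.start_worker(9001)` on each of them); the addresses below are placeholders.

import ydf

tuner = ydf.RandomSearchTuner(num_trials=50, automatic_search_space=True)
learner = ydf.GradientBoostedTreesLearner(
    label="income",
    tuner=tuner,
    # Placeholder addresses of already-running YDF workers.
    workers=["192.168.0.10:9001", "192.168.0.11:9001"],
)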
Download dataset¶
We use the adult dataset.
import ydf # Yggdrasil Decision Forests
import pandas as pd # We use Pandas to load small datasets
# Download a classification dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Print the first 5 training examples
train_ds.head(5)
|   | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|-----|-----------|--------|-----------|---------------|----------------|------------|--------------|------|-----|--------------|--------------|----------------|----------------|--------|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
Local tuning with manually set hyper-parameters¶
The hyper-parameters of a learner are listed in the API reference and on the hyper-parameter page. The guide How to improve a model also gives recommendations on the hyper-parameters that are most impactful to optimize. In this example, we train a gradient boosted trees model and optimize the following hyper-parameters: `shrinkage`, `subsample`, and `max_depth`.
The tuning objective is selected automatically for the model. For instance, for the `GradientBoostedTreesLearner` used in this example, the loss is minimized.
Let's configure the tuner:
tuner = ydf.RandomSearchTuner(num_trials=50)
tuner.choice("shrinkage", [0.2, 0.1, 0.05])
tuner.choice("subsample", [1.0, 0.9, 0.8])
tuner.choice("max_depth", [3, 4, 5, 6])
<ydf.learner.tuner.SearchSpace at 0x7f3eb4372310>
We create a learner using this tuner and train a model:
Note: Parameters that are not tuned can be specified directly on the learner.
Note: To print the tuning logs during tuning, enable logging with `ydf.verbose(2)`.
learner = ydf.GradientBoostedTreesLearner(
label="income",
num_trees=100, # Used for all the trials.
tuner=tuner,
)
model = learner.train(train_ds)
Train model on 22792 examples
Model trained in 0:00:03.998356
The tuning logs, i.e., the list of hyper-parameter combinations that were tested and their scores, are available in the "tuning" tab of the model description. Note that only 36 trials are listed even though `num_trials=50`: the search space defined above contains just 3 × 3 × 4 = 36 unique combinations.
model.describe()
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : Yes
Model size : 543 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%)
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%)
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%)
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
A tuner automatically selects the hyper-parameters of a learner.
trial | score | duration (s) | shrinkage | subsample | max_depth |
---|---|---|---|---|---|
16 | -0.574861 | 2.49348 | 0.2 | 1 | 5 |
31 | -0.576405 | 3.53616 | 0.2 | 1 | 6 |
15 | -0.577211 | 2.4727 | 0.1 | 1 | 5 |
33 | -0.578941 | 3.69053 | 0.2 | 0.9 | 5 |
32 | -0.579071 | 3.54803 | 0.2 | 0.9 | 6 |
35 | -0.579637 | 3.99118 | 0.1 | 1 | 6 |
19 | -0.581703 | 2.68832 | 0.2 | 0.8 | 6 |
34 | -0.582941 | 3.90171 | 0.1 | 0.8 | 6 |
14 | -0.583348 | 2.46785 | 0.2 | 0.8 | 5 |
27 | -0.583466 | 3.23896 | 0.2 | 0.9 | 4 |
10 | -0.58463 | 2.14364 | 0.2 | 1 | 4 |
22 | -0.584824 | 2.97681 | 0.1 | 0.9 | 6 |
13 | -0.585809 | 2.46436 | 0.1 | 0.9 | 5 |
12 | -0.587067 | 2.29765 | 0.1 | 0.8 | 5 |
8 | -0.590813 | 1.97632 | 0.2 | 0.8 | 4 |
24 | -0.593991 | 3.0293 | 0.05 | 1 | 6 |
9 | -0.595175 | 2.14037 | 0.1 | 1 | 4 |
21 | -0.596592 | 2.91333 | 0.05 | 0.8 | 6 |
28 | -0.597159 | 3.2767 | 0.1 | 0.9 | 4 |
20 | -0.597244 | 2.90384 | 0.05 | 0.9 | 6 |
6 | -0.597766 | 1.96352 | 0.1 | 0.8 | 4 |
5 | -0.603554 | 1.71404 | 0.2 | 1 | 3 |
23 | -0.60517 | 3.01335 | 0.2 | 0.9 | 3 |
18 | -0.605849 | 2.54463 | 0.05 | 0.9 | 5 |
0 | -0.606706 | 1.49037 | 0.2 | 0.8 | 3 |
17 | -0.607283 | 2.511 | 0.05 | 0.8 | 5 |
30 | -0.608091 | 3.47695 | 0.05 | 1 | 5 |
25 | -0.619956 | 3.17843 | 0.1 | 0.9 | 3 |
3 | -0.620752 | 1.63833 | 0.1 | 0.8 | 3 |
4 | -0.621349 | 1.70712 | 0.1 | 1 | 3 |
7 | -0.625488 | 1.96705 | 0.05 | 0.8 | 4 |
29 | -0.626953 | 3.43528 | 0.05 | 0.9 | 4 |
11 | -0.62982 | 2.16092 | 0.05 | 1 | 4 |
1 | -0.656424 | 1.57613 | 0.05 | 0.8 | 3 |
26 | -0.656732 | 3.20212 | 0.05 | 1 | 3 |
2 | -0.656747 | 1.62633 | 0.05 | 0.9 | 3 |
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.574861
Accuracy: 0.87251  CI95[W][0 1]
ErrorRate: 0.12749

Confusion Table:
truth\prediction  <=50K  >50K
<=50K              1570    94
>50K                194   401
Total: 2259
Variable importances measure the importance of an input feature for a model.
1. "age" 0.257622 ################ 2. "capital_gain" 0.249047 ############# 3. "relationship" 0.244032 ########### 4. "occupation" 0.242881 ########### 5. "hours_per_week" 0.238530 ########## 6. "education" 0.237441 ######### 7. "marital_status" 0.234935 ######## 8. "capital_loss" 0.231145 ####### 9. "fnlwgt" 0.226059 ###### 10. "native_country" 0.225767 ###### 11. "workclass" 0.220718 #### 12. "education_num" 0.219033 #### 13. "sex" 0.211384 # 14. "race" 0.206124
1. "capital_gain" 11.000000 ################ 2. "age" 10.000000 ############## 3. "hours_per_week" 10.000000 ############## 4. "relationship" 9.000000 ############ 5. "marital_status" 7.000000 ######### 6. "education" 6.000000 ######## 7. "capital_loss" 6.000000 ######## 8. "fnlwgt" 5.000000 ###### 9. "workclass" 3.000000 ### 10. "education_num" 3.000000 ### 11. "sex" 3.000000 ### 12. "occupation" 1.000000 13. "race" 1.000000
1. "occupation" 144.000000 ################ 2. "age" 121.000000 ############# 3. "education" 113.000000 ############ 4. "capital_gain" 111.000000 ############ 5. "capital_loss" 90.000000 ######### 6. "native_country" 87.000000 ######### 7. "fnlwgt" 84.000000 ######### 8. "relationship" 73.000000 ####### 9. "marital_status" 68.000000 ####### 10. "hours_per_week" 64.000000 ###### 11. "workclass" 49.000000 ##### 12. "education_num" 28.000000 ## 13. "sex" 14.000000 # 14. "race" 5.000000
1. "relationship" 1675.422986 ################ 2. "capital_gain" 1040.150118 ######### 3. "education_num" 687.196583 ###### 4. "occupation" 526.056194 ##### 5. "marital_status" 469.469421 #### 6. "age" 289.979275 ## 7. "capital_loss" 281.277707 ## 8. "education" 259.256109 ## 9. "hours_per_week" 181.939375 # 10. "native_country" 108.750643 # 11. "workclass" 64.136268 12. "fnlwgt" 46.873309 13. "sex" 30.074515 14. "race" 2.153583
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Only printing the first tree.
Tree #0: "relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-8.31766e-09 ├─(pos)─ "education_num">=12.5 [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.233866 | ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.545366 | | ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.832346 | | | ├─(pos)─ pred:0.834828 | | | └─(neg)─ pred:0.619473 | | └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.492116 | | ├─(pos)─ pred:0.813402 | | └─(neg)─ pred:0.453839 | └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0997371 | ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.810859 | | ├─(pos)─ pred:0.634856 | | └─(neg)─ pred:0.839967 | └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0646271 | ├─(pos)─ pred:0.205598 | └─(neg)─ pred:-0.0218904 └─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.190336 ├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.795647 | ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.811553 | | ├─(pos)─ pred:0.833976 | | └─(neg)─ pred:0.398979 | └─(neg)─ pred:0.178485 └─(neg)─ "education" is in [BITMAP] {<OOD>, Bachelors, Masters, Prof-school, Doctorate} [s:0.00229611 n:11121 np:2199 miss:1] ; pred:-0.207979 ├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.10157 | ├─(pos)─ pred:-0.0207104 | └─(neg)─ pred:-0.210678 └─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.234206 ├─(pos)─ pred:0.14084 └─(neg)─ pred:-0.235938
The model can then be evaluated as usual.
model.evaluate(test_ds)
Label \ Pred | <=50K | >50K |
---|---|---|
<=50K | 6974 | 438 |
>50K | 781 | 1576 |
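Individual metrics are also exposed as attributes of the returned evaluation object, for example (the exact set of metrics depends on the task):

evaluation = model.evaluate(test_ds)
# Scalar metrics are attributes of the evaluation object.
print("Test accuracy:", evaluation.accuracy)
print("Test loss:", evaluation.loss)

A richer analysis, including permutation variable importances computed on the test data, is available with model.analyze(test_ds).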