With JAX¶
About this tutorial¶
JAX is a machine learning library for training neural network models. While the decision forests trained by YDF differ from neural networks, YDF and JAX can be combined to create powerful hybrid models.
This tutorial is divided into two parts. First, we show how to convert a YDF model into a JAX model, and how to save the resulting model as a SavedModel using jax2tf.
Second, we show how YDF and JAX can be combined to solve the distribution shift problem: We train a YDF model, convert it to a JAX model, finetune it using JAX, and convert it back to a YDF model.
Setup¶
# Install dependencies
!pip install ydf -U -q
!pip install tensorflow -U -q
!pip install optax pandas numpy -U -q
!pip install jax[cpu] -U
# OR
# !pip install jax[cuda12] -U -q
# See https://jax.readthedocs.io/en/latest/installation.html for JAX variations.
import tempfile
import jax
from jax.experimental import jax2tf # To export JAX model to SavedModel
import optax # To finetune YDF+JAX models
import pandas as pd # We use Pandas to load small datasets
import tensorflow as tf # To create SavedModels
import ydf # Yggdrasil Decision Forests
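The installation comments above offer CPU and CUDA variants of JAX. To confirm which backend JAX actually picked up, a quick check using plain JAX APIs (nothing YDF-specific, just a sanity check):
# Print the selected backend ("cpu", "gpu" or "tpu") and the available devices.
print(jax.default_backend())
print(jax.devices())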
Convert YDF model into a JAX function¶
In this section, we train a YDF model on the Adult dataset, convert it into a JAX function, and demonstrate various operations.
First let's download a binary classification dataset.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Download and load the dataset as Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
label = "income"
# Print the first 5 training examples
train_ds.head(5)
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
First, we train a YDF model on the dataset.
learner = ydf.GradientBoostedTreesLearner(label=label)
model = learner.train(train_ds)
Train model on 22792 examples
Model trained in 0:00:02.277830
We convert the YDF model into a JAX function.
jax_model = model.to_jax_function()
The jax_model object contains three fields:

- predict: A JAX function making predictions.
- encoder: A callable class to prepare examples for predict. Since JAX does not support string values, categorical string input features have to be prepared before calling predict.
- params: An optional dictionary of JAX arrays defining the differentiable parameters of the model. By default, params is None and predict does not accept any parameters. We show how to use params in the second section.
We generate predictions for the first 5 examples in the test set.
First, we select some examples and encode them.
# Select the first 5 examples from the Pandas Dataframe and remove the labels.
selected_examples = test_ds[:5].drop(model.label(), axis=1)
# Encode the examples into a dictionary of JAX arrays.
jax_selected_examples = jax_model.encoder(selected_examples)
jax_selected_examples
{'age': Array([39, 40, 40, 35, 23], dtype=int32), 'workclass': Array([4, 1, 1, 6, 3], dtype=int32), 'fnlwgt': Array([ 77516, 121772, 193524, 76845, 190709], dtype=int32), 'education': Array([ 3, 5, 13, 11, 7], dtype=int32), 'education_num': Array([13, 11, 16, 5, 12], dtype=int32), 'marital_status': Array([2, 1, 1, 1, 2], dtype=int32), 'occupation': Array([ 4, 3, 1, 10, 12], dtype=int32), 'relationship': Array([2, 1, 1, 1, 2], dtype=int32), 'race': Array([1, 3, 1, 2, 1], dtype=int32), 'sex': Array([1, 1, 1, 1, 1], dtype=int32), 'capital_gain': Array([2174, 0, 0, 0, 0], dtype=int32), 'capital_loss': Array([0, 0, 0, 0, 0], dtype=int32), 'hours_per_week': Array([40, 40, 60, 40, 52], dtype=int32), 'native_country': Array([1, 0, 1, 1, 1], dtype=int32)}
Then, we generate the predictions.
jax_predictions = jax_model.predict(jax_selected_examples)
jax_predictions
Array([0.01860434, 0.36130956, 0.83858865, 0.04385566, 0.02917648], dtype=float32)
Note that the predictions of the JAX function are equal to the predictions of the YDF model (modulo float rounding errors).
model.predict(selected_examples)
array([0.01860435, 0.36130956, 0.83858865, 0.04385567, 0.02917649], dtype=float32)
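To check this programmatically, here is a minimal sketch using NumPy's testing utilities (NumPy is installed in the setup above but not imported there, so we import it here):
import numpy as np
# The two sets of predictions should match up to float rounding errors.
np.testing.assert_allclose(
    np.asarray(jax_predictions), model.predict(selected_examples), atol=1e-5
)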
JAX does not define a model serialization format, i.e., a way to save a model to disk. Instead, to save a JAX model for serving, it is common to export it as a SavedModel.
# Create a TF module with the model.
tf_model = tf.Module()
tf_model.predict = tf.function(
jax2tf.convert(jax_model.predict, with_gradient=False),
jit_compile=True,
autograph=False,
)
# Check the predictions of the TF module.
tf_selected_examples = {
k: tf.constant(v) for k, v in jax_selected_examples.items()
}
tf_predictions = tf_model.predict(tf_selected_examples)
tf_predictions
<tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.01860434, 0.36130956, 0.83858865, 0.04385566, 0.02917648], dtype=float32)>
# Save the TF module to file.
with tempfile.TemporaryDirectory() as tempdir:
tf.saved_model.save(tf_model, tempdir)
INFO:tensorflow:Assets written to: /tmp/tmp90flesgr/assets
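As a sanity check, the SavedModel can be reloaded and queried before the temporary directory is deleted. A minimal sketch; the predict signature was already traced above with tf_selected_examples, so we reuse the same input structure:
# Save, reload, and query the SavedModel within the same temporary directory.
with tempfile.TemporaryDirectory() as tempdir:
    tf.saved_model.save(tf_model, tempdir)
    reloaded_model = tf.saved_model.load(tempdir)
    print(reloaded_model.predict(tf_selected_examples))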
INFO: YDF's to_tensorflow_saved_model function can create a SavedModel directly. This approach results in faster models, but it requires installing TensorFlow Decision Forests.
try:
with tempfile.TemporaryDirectory() as tempdir:
# Save the YDF model to a SavedModel directly.
model.to_tensorflow_saved_model(tempdir, mode="tf")
except Exception as e:
print("Could not save YDF model to SavedModel with to_tensorflow_saved_model")
[INFO 24-06-14 14:31:56.6553 CEST kernel.cc:1233] Loading model from path /tmp/tmp71lnhoy9/tmp83xu8mjt/ with prefix e57777e0_ [INFO 24-06-14 14:31:56.6795 CEST quick_scorer_extended.cc:911] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference. [INFO 24-06-14 14:31:56.6803 CEST abstract_model.cc:1362] Engine "GradientBoostedTreesQuickScorerExtended" built [INFO 24-06-14 14:31:56.6803 CEST kernel.cc:1061] Use fast generic engine
INFO:tensorflow:Assets written to: /tmp/tmpi0fp69xz/assets
Fine tune a YDF model with JAX¶
A distribution shift problem occurs when the examples of interest (serving examples) follow a different distribution than the training dataset. For example, distribution shift occurs in hospitals when a model is trained on data acquired with different devices. Although datasets from different devices should be compatible, subtle differences between them cause a model trained on one dataset to perform poorly on another. For instance, a machine learning model trained to detect tumors on images captured by one device might not work effectively on images captured by a device from another brand. Distribution shifts are also common in dynamic systems that change over time (e.g., user behaviors).
In this section, we solve a distribution shift issue using finetuning. For that, we use the Adult dataset with a twist. We assume that only people with "relationship=Wife" are of interest. However, only 5% of the people are in this category, so we have few training examples.
We will first observe that training only on relationship=Wife examples or training on all available examples does not produce the best model. Instead, we will train a YDF model on all examples, finetune it with JAX on the relationship=Wife examples, and observe that this finetuned model performs better. Finally, the finetuned JAX model will be converted back into a YDF model and analyzed using YDF tools.
INFO: This section assumes you are familiar with JAX and Optax.
First, let's print the distribution of relationship in the test examples. Our objective is to optimize the quality of the model on the 483 relationship == Wife examples.
test_ds["relationship"].value_counts()
relationship
Husband           4002
Not-in-family     2505
Own-child         1521
Unmarried          948
Wife               483
Other-relative     310
Name: count, dtype: int64
We divide the dataset into two groups: group A contains the relationship != Wife examples and group B contains the relationship == Wife examples.
def is_group_B(ds):
return ds["relationship"] == "Wife"
train_ds_group_A = train_ds[~is_group_B(train_ds)]
test_ds_group_A = test_ds[~is_group_B(test_ds)]
train_ds_group_B = train_ds[is_group_B(train_ds)]
test_ds_group_B = test_ds[is_group_B(test_ds)]
print("Number of examples per group")
print("\tTrain Group A:", len(train_ds_group_A))
print("\tTest Group A:", len(test_ds_group_A))
print("\tTrain Group B:", len(train_ds_group_B))
print("\tTest Group B:", len(test_ds_group_B))
Number of examples per group
    Train Group A: 21707
    Test Group A: 9286
    Train Group B: 1085
    Test Group B: 483
Note that group A contains more examples than group B, but what we care about is the test examples in group B.
Let's train and evaluate three models on different combinations of groups A and B. Those will be our baselines.
# Train model on group A
model_group_A = ydf.GradientBoostedTreesLearner(label=label).train(
train_ds_group_A, verbose=0
)
# Train model on group B
model_group_B = ydf.GradientBoostedTreesLearner(label=label).train(
train_ds_group_B, verbose=0
)
# Train model on group A + B
model_group_AB = ydf.GradientBoostedTreesLearner(label=label).train(
train_ds, verbose=0
)
# Evaluate the models on group B
accuracy_test_B_model_A = model_group_A.evaluate(test_ds_group_B).accuracy
accuracy_test_B_model_B = model_group_B.evaluate(test_ds_group_B).accuracy
accuracy_test_B_model_AB = model_group_AB.evaluate(test_ds_group_B).accuracy
print("Accuracy on B, model trained on A:", accuracy_test_B_model_A)
print("Accuracy on B, model trained on B:", accuracy_test_B_model_B)
print("Accuracy on B, model trained on A+B:", accuracy_test_B_model_AB)
Accuracy on B, model trained on A: 0.7204968944099379
Accuracy on B, model trained on B: 0.7329192546583851
Accuracy on B, model trained on A+B: 0.7556935817805382
The model trained on both groups A and B performs best on group B. Can we do better?
Let's convert the model trained on A+B into a JAX function.
jax_model_group_AB = model_group_AB.to_jax_function(
apply_activation=False,
leaves_as_params=True,
)
jax_model_group_AB.params
{'leaf_values': Array([-0.1233467 , -0.0927111 , 0.2927755 , ..., 0.05464426, 0.12556875, -0.11374608], dtype=float32), 'initial_predictions': Array([-1.1630996], dtype=float32)}
Note that:

- apply_activation=False removes the activation function from the model. This allows the loss to be computed on logits rather than probabilities, which makes finetuning more numerically stable.
- leaves_as_params=True specifies that the leaf values are exported as model parameters in params. This is necessary to finetune the model.
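Because the activation is not applied, predict now returns logits instead of probabilities. A minimal sketch, reusing selected_examples from the first section, showing how to recover probabilities with a sigmoid:
# Encode the examples, compute logits with the model parameters, and apply
# the sigmoid manually to recover probabilities.
encoded_examples = jax_model_group_AB.encoder(selected_examples)
logits = jax_model_group_AB.predict(encoded_examples, jax_model_group_AB.params)
probabilities = jax.nn.sigmoid(logits)
probabilities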
To finetune the model, we need to generate batches of examples. The following block generates such batches.
def get_num_examples(ds):
return len(next(iter(ds.values())))
def prepare_dataset(ds, jax_model, batch=100):
ds = ds.copy()
# Make the label boolean
ds[label] = ds[label] == ">50K"
# Encode the input features
encoded_ds = jax_model.encoder(ds)
# Yield batches of examples
n = get_num_examples(encoded_ds)
i = 0
while i < n:
begin_idx = i
end_idx = min(i + batch, n)
yield {k: v[begin_idx:end_idx] for k, v in encoded_ds.items()}
i += batch
# Example usage of "prepare_dataset".
for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB, batch=4):
print(examples)
break # We only print the first batch
{'age': Array([44, 67, 26, 30], dtype=int32), 'workclass': Array([1, 5, 0, 1], dtype=int32), 'fnlwgt': Array([228057, 171564, 167835, 118551], dtype=int32), 'education': Array([9, 1, 3, 3], dtype=int32), 'education_num': Array([ 4, 9, 13, 13], dtype=int32), 'marital_status': Array([1, 1, 1, 1], dtype=int32), 'occupation': Array([ 7, 1, 0, 11], dtype=int32), 'relationship': Array([5, 5, 5, 5], dtype=int32), 'race': Array([1, 1, 1, 1], dtype=int32), 'sex': Array([2, 2, 2, 2], dtype=int32), 'capital_gain': Array([ 0, 20051, 0, 0], dtype=int32), 'capital_loss': Array([0, 0, 0, 0], dtype=int32), 'hours_per_week': Array([40, 30, 20, 16], dtype=int32), 'native_country': Array([12, 10, 1, 1], dtype=int32), 'income': Array([False, True, False, True], dtype=bool)}
Let's define utilities to compute and print the loss and accuracy of the model.
@jax.jit
def compute_accuracy(params, examples, logit=True):
examples = examples.copy()
labels = examples.pop(model.label())
predictions = jax_model_group_AB.predict(examples, params)
return ((predictions >= 0.0) == labels).mean()
@jax.jit
def compute_loss(params, examples):
examples = examples.copy()
labels = examples.pop(model.label())
logits = jax_model_group_AB.predict(examples, params)
return optax.sigmoid_binary_cross_entropy(logits, labels).mean()
def compute_metric(metric_fn, ds):
sum_metrics = 0
num_examples = 0
for examples in prepare_dataset(ds, jax_model_group_AB):
n = get_num_examples(examples)
sum_metrics += n * metric_fn(jax_model_group_AB.params, examples)
num_examples += n
return float(sum_metrics / num_examples)
def print_logs(stage):
train_accuracy = compute_metric(compute_accuracy, train_ds_group_B)
train_loss = compute_metric(compute_loss, train_ds_group_B)
test_accuracy = compute_metric(compute_accuracy, test_ds_group_B)
test_loss = compute_metric(compute_loss, test_ds_group_B)
print(
f"stage:{stage:10} "
f"test-accuracy:{test_accuracy:.5f} test-loss:{test_loss:.5f} "
f"train-accuracy:{train_accuracy:.5f} train-loss:{train_loss:.5f}"
)
# Metrics of the model before training.
print_logs("initial")
stage:initial test-accuracy:0.75569 test-loss:0.47798 train-accuracy:0.83963 train-loss:0.37099
Following is the training loop.
optimizer = optax.adam(0.001)
@jax.jit
def train_step(opt_state, mdl_state, examples):
loss, grads = jax.value_and_grad(compute_loss)(mdl_state, examples)
updates, opt_state = optimizer.update(grads, opt_state)
mdl_state = optax.apply_updates(mdl_state, updates)
return opt_state, mdl_state, loss
opt_state = optimizer.init(jax_model_group_AB.params)
for epoch_idx in range(10):
print_logs(f"epoch_{epoch_idx}")
for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB):
opt_state, jax_model_group_AB.params, _ = train_step(
opt_state, jax_model_group_AB.params, examples
)
print_logs("final")
stage:epoch_0   test-accuracy:0.75569 test-loss:0.47798 train-accuracy:0.83963 train-loss:0.37099
stage:epoch_1   test-accuracy:0.75155 test-loss:0.48035 train-accuracy:0.84424 train-loss:0.36520
stage:epoch_2   test-accuracy:0.75776 test-loss:0.47823 train-accuracy:0.84240 train-loss:0.35878
stage:epoch_3   test-accuracy:0.75983 test-loss:0.48016 train-accuracy:0.84608 train-loss:0.35352
stage:epoch_4   test-accuracy:0.75776 test-loss:0.48063 train-accuracy:0.84793 train-loss:0.34862
stage:epoch_5   test-accuracy:0.75569 test-loss:0.48173 train-accuracy:0.85069 train-loss:0.34419
stage:epoch_6   test-accuracy:0.75776 test-loss:0.48283 train-accuracy:0.85346 train-loss:0.34008
stage:epoch_7   test-accuracy:0.75776 test-loss:0.48381 train-accuracy:0.85806 train-loss:0.33622
stage:epoch_8   test-accuracy:0.75983 test-loss:0.48495 train-accuracy:0.86175 train-loss:0.33260
stage:epoch_9   test-accuracy:0.75983 test-loss:0.48595 train-accuracy:0.86267 train-loss:0.32917
stage:final     test-accuracy:0.75983 test-loss:0.48703 train-accuracy:0.86359 train-loss:0.32592
Notice that both the test and training accuracy improve during training.
We can now update the YDF model with the finetuned weights.
model_group_AB.update_with_jax_params(jax_model_group_AB.params)
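As a quick consistency check, the updated YDF model should now produce (after its sigmoid activation) the same predictions as the finetuned JAX function, up to float rounding. A minimal sketch, reusing selected_examples from the first section:
import numpy as np
# Compare the updated YDF model against the finetuned JAX function.
encoded_examples = jax_model_group_AB.encoder(selected_examples)
finetuned_logits = jax_model_group_AB.predict(
    encoded_examples, jax_model_group_AB.params
)
np.testing.assert_allclose(
    model_group_AB.predict(selected_examples),
    np.asarray(jax.nn.sigmoid(finetuned_logits)),
    atol=1e-5,
)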
model_group_AB is the finetuned model. Let's evaluate and compare it to the other models:
accuracy_test_B_model_AB_finetuned_B = model_group_AB.evaluate(
test_ds_group_B
).accuracy
print("Accuracy on B, model trained on A:", accuracy_test_B_model_A)
print("Accuracy on B, model trained on B:", accuracy_test_B_model_B)
print("Accuracy on B, model trained on A+B:", accuracy_test_B_model_AB)
print("==================================")
print(
"Accuracy on B, model trained on A+B, finetuned on B:",
accuracy_test_B_model_AB_finetuned_B,
)
Accuracy on B, model trained on A: 0.7204968944099379
Accuracy on B, model trained on B: 0.7329192546583851
Accuracy on B, model trained on A+B: 0.7556935817805382
==================================
Accuracy on B, model trained on A+B, finetuned on B: 0.7598343685300207
Notice that the new model, trained on A+B and finetuned on B, shows the best test accuracy.
model_group_AB is a YDF model like any other. For instance, you can save it and analyze it.
# Save the model
with tempfile.TemporaryDirectory() as tempdir:
model_group_AB.save(tempdir)
# Analyse the model
model_group_AB.analyze(test_ds_group_B)
Variable importances measure the importance of an input feature for a model.
1. "capital_gain" 0.049689 ################ 2. "occupation" 0.045549 ############## 3. "education" 0.026915 ######## 4. "education_num" 0.026915 ######## 5. "age" 0.018634 ###### 6. "capital_loss" 0.018634 ###### 7. "workclass" 0.014493 ##### 8. "fnlwgt" 0.002070 # 9. "native_country" 0.002070 # 10. "relationship" 0.000000 11. "race" 0.000000 12. "sex" 0.000000 13. "hours_per_week" 0.000000 14. "marital_status" -0.002070
1. "capital_gain" 0.164288 ################ 2. "capital_loss" 0.048263 ##### 3. "occupation" 0.033196 ### 4. "education" 0.023903 ## 5. "education_num" 0.015137 ## 6. "age" 0.013872 # 7. "workclass" 0.006274 # 8. "race" 0.002477 9. "sex" 0.001453 10. "fnlwgt" 0.000984 11. "marital_status" 0.000722 12. "relationship" 0.000000 13. "native_country" -0.000019 14. "hours_per_week" -0.007143
1. "capital_gain" 0.083385 ################ 2. "occupation" 0.040765 ######## 3. "capital_loss" 0.030647 ###### 4. "education" 0.026051 ##### 5. "age" 0.024419 ##### 6. "education_num" 0.016887 #### 7. "workclass" 0.010427 ## 8. "race" 0.003161 # 9. "marital_status" 0.000790 # 10. "sex" 0.000704 # 11. "relationship" 0.000000 # 12. "native_country" -0.000361 # 13. "fnlwgt" -0.001022 14. "hours_per_week" -0.006107
1. "capital_gain" 0.162868 ################ 2. "capital_loss" 0.048043 ##### 3. "occupation" 0.033135 ### 4. "education" 0.023881 ## 5. "education_num" 0.015116 ## 6. "age" 0.013875 # 7. "workclass" 0.006275 # 8. "race" 0.002472 9. "sex" 0.001448 10. "fnlwgt" 0.000990 11. "marital_status" 0.000721 12. "relationship" 0.000000 13. "native_country" -0.000014 14. "hours_per_week" -0.007106
1. "age" 0.226642 ################ 2. "occupation" 0.219727 ############# 3. "capital_gain" 0.214876 ############ 4. "education" 0.213746 ########### 5. "marital_status" 0.212739 ########### 6. "relationship" 0.206040 ######### 7. "fnlwgt" 0.203843 ######## 8. "hours_per_week" 0.203735 ######## 9. "capital_loss" 0.196549 ###### 10. "native_country" 0.190548 #### 11. "workclass" 0.187795 ### 12. "education_num" 0.184215 ## 13. "race" 0.180495 14. "sex" 0.177647
1. "age" 26.000000 ################ 2. "capital_gain" 26.000000 ################ 3. "marital_status" 20.000000 ############ 4. "relationship" 17.000000 ########## 5. "capital_loss" 14.000000 ######## 6. "hours_per_week" 14.000000 ######## 7. "education" 12.000000 ####### 8. "fnlwgt" 10.000000 ##### 9. "race" 9.000000 ##### 10. "education_num" 7.000000 ### 11. "sex" 4.000000 # 12. "occupation" 2.000000 13. "workclass" 1.000000 14. "native_country" 1.000000
1. "occupation" 724.000000 ################ 2. "fnlwgt" 513.000000 ########### 3. "age" 483.000000 ########## 4. "education" 464.000000 ########## 5. "hours_per_week" 339.000000 ####### 6. "capital_gain" 326.000000 ###### 7. "native_country" 306.000000 ###### 8. "capital_loss" 297.000000 ###### 9. "relationship" 262.000000 ##### 10. "workclass" 244.000000 ##### 11. "marital_status" 210.000000 #### 12. "education_num" 82.000000 # 13. "sex" 42.000000 14. "race" 21.000000
1. "relationship" 3014.690076 ################ 2. "capital_gain" 2065.521668 ########## 3. "education" 1144.490954 ###### 4. "marital_status" 1111.389695 ##### 5. "occupation" 1094.619502 ##### 6. "education_num" 796.666823 #### 7. "capital_loss" 584.055066 ### 8. "age" 582.288569 ### 9. "hours_per_week" 366.856509 # 10. "native_country" 263.872689 # 11. "fnlwgt" 216.537764 # 12. "workclass" 196.085503 # 13. "sex" 47.217730 14. "race" 5.428727