In Java
Setup
pip install ydf -U
How can I use the Java Standalone export?
YDF models can be integrated in two ways:
- Direct Code Generation: Call model.to_standalone_java() to generate the source code. This option is simple and great for experimentation.
- Build Rule Integration: For production use, save your model (e.g., in Google3) and use a java_ydf_standalone_model Blaze rule. This rule automatically calls to_standalone_java during compilation, which simplifies model updates and testing different export options. Note that this build rule is currently not available in the open-source build / Bazel.
Both methods are demonstrated in this tutorial.
Import libraries
import pandas as pd
import ydf
Training a small model
First, we train a small YDF model on the Adult dataset.
# Download a classification dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
# Note: Only train 2 trees to make the generated code smaller.
model = ydf.GradientBoostedTreesLearner(label="income", num_trees=2).train(
    train_ds
)
model.describe()
Train model on 22792 examples
Model trained in 0:00:00.025254
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Trained with Feature Selection : No
Model size : 40 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%) dtype:DTYPE_BYTES
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%) dtype:DTYPE_BYTES
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%) dtype:DTYPE_BYTES
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%) dtype:DTYPE_BYTES
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%) dtype:DTYPE_BYTES
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%) dtype:DTYPE_BYTES
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%) dtype:DTYPE_BYTES
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%) dtype:DTYPE_BYTES
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%) dtype:DTYPE_BYTES

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661 dtype:DTYPE_INT64
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423 dtype:DTYPE_INT64
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427 dtype:DTYPE_INT64
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48 dtype:DTYPE_INT64
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01 dtype:DTYPE_INT64
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249 dtype:DTYPE_INT64

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 1.00595
Accuracy: 0.736609 CI95[W][0 1]
ErrorRate: : 0.263391
Confusion Table:
truth\prediction
         <=50K  >50K
<=50K     1664     0
>50K       595     0
Total: 2259
Variable importances measure the importance of an input feature for a model.
INV_MEAN_MIN_DEPTH:
1. "relationship" 1.000000 ################
2. "capital_gain" 0.393125 ####
3. "education_num" 0.271300 #
4. "age" 0.213460
5. "education" 0.200074
6. "occupation" 0.189986
7. "capital_loss" 0.186946
8. "fnlwgt" 0.173264
9. "hours_per_week" 0.172198
10. "workclass" 0.170141
11. "native_country" 0.170141
NUM_AS_ROOT:
1. "relationship" 2.000000
NUM_NODES:
1. "capital_gain" 10.000000 ################
2. "age" 9.000000 ##############
3. "occupation" 8.000000 ############
4. "capital_loss" 8.000000 ############
5. "education" 5.000000 #######
6. "fnlwgt" 4.000000 #####
7. "education_num" 4.000000 #####
8. "hours_per_week" 3.000000 ###
9. "relationship" 2.000000 #
10. "workclass" 1.000000
11. "native_country" 1.000000
SUM_SCORE:
1. "relationship" 1358.045196 ################
2. "capital_gain" 592.782820 ######
3. "education_num" 581.188269 ######
4. "occupation" 153.072061 #
5. "capital_loss" 80.772546
6. "education" 80.057732
7. "age" 56.385846
8. "hours_per_week" 8.637064
9. "fnlwgt" 5.569371
10. "native_country" 3.053526
11. "workclass" 0.221624
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Below is the first tree of the model. The model contains 2 trees, which jointly make the prediction. Other trees can be printed with `model.print_tree(tree_idx)` or plotted with `model.plot_tree(tree_idx)`.
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
├─(pos)─ "education_num">=12.5 [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
|   ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
|   |   ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
|   |   |   ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
|   |   |   |   ├─(pos)─ pred:0.309737
|   |   |   |   └─(neg)─ pred:0.418684
|   |   |   └─(neg)─ pred:0.309737
|   |   └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
|   |       ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
|   |       |   ├─(pos)─ pred:0.349312
|   |       |   └─(neg)─ pred:0.417359
|   |       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
|   |           ├─(pos)─ pred:0.253437
|   |           └─(neg)─ pred:0.11557
|   └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
|       ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
|       |   ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
|       |   |   ├─(pos)─ pred:0.397934
|       |   |   └─(neg)─ pred:0.205614
|       |   └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
|       |       ├─(pos)─ pred:0.419984
|       |       └─(neg)─ pred:0.419984
|       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
|           ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
|           |   ├─(pos)─ pred:0.132992
|           |   └─(neg)─ pred:0.00826457
|           └─(neg)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Bachelors, Masters, Assoc-voc, Assoc-acdm, Prof-school, Doctorate} [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
|               ├─(pos)─ pred:0.00969668
|               └─(neg)─ pred:-0.0813718
└─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
    ├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
    |   ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
    |   |   ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
    |   |   |   ├─(pos)─ pred:0.392422
    |   |   |   └─(neg)─ pred:0.419984
    |   |   └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
    |   |       ├─(pos)─ pred:0.419984
    |   |       └─(neg)─ pred:-0.0210046
    |   └─(neg)─ pred:0.0892425
    └─(neg)─ "education" is in [BITMAP] {<OOD>, Bachelors, Masters, Prof-school, Doctorate} [s:0.00229611 n:11121 np:2199 miss:1] ; pred:-0.10399
        ├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
        |   ├─(pos)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Assoc-voc, 11th, Assoc-acdm, 10th, 7th-8th, Prof-school, 9th, ...[5 left]} [s:0.0110157 n:1263 np:125 miss:1] ; pred:-0.0103552
        |   |   ├─(pos)─ pred:0.16421
        |   |   └─(neg)─ pred:-0.0295298
        |   └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
        |       ├─(pos)─ pred:0.19949
        |       └─(neg)─ pred:-0.106976
        └─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
            ├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
            |   ├─(pos)─ pred:-0.0328167
            |   └─(neg)─ pred:0.292776
            └─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
                ├─(pos)─ pred:-0.0927111
                └─(neg)─ pred:-0.123347
Direct Code Generation
Let's generate the model .java file and the model data .bin file.
The .java file contains the following symbols:
- Instance class: An input example.
- Predict method: A thread-safe method that consumes an Instance and returns a label class / probability (for classification) or value (for regression).
- Label enum: The label values. In this case, this is a binary classification model with two labels, Label.LT_50K and Label.GT_50K.
- Categorical enums: An enum class for each of the categorical input features, e.g., FeatureWorkclass, FeatureEducation.
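Because the Predict method is described as thread-safe, a single model instance can be shared by many threads instead of being created once per thread. The following is a minimal sketch under that assumption: SharedModelBatch and predictAll are hypothetical names, and the Function argument stands in for a method reference such as model::Predict on the generated class (assuming Predict takes an Instance, returns a Label, and declares no checked exceptions).

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical helper: scores a batch in parallel with ONE shared predictor.
// Sharing is only safe because the generated Predict method is documented as
// thread-safe; a non-thread-safe predictor would need one copy per thread.
final class SharedModelBatch {
  private SharedModelBatch() {}

  static <I, O> List<O> predictAll(Function<I, O> predictor, List<I> instances) {
    return instances
        .parallelStream()               // Work is split across the common ForkJoinPool.
        .map(predictor)                 // Every worker calls the same predictor instance.
        .collect(Collectors.toList());  // Results keep the order of the input list.
  }
}
```

With the generated code, a call could look like `List<Label> labels = SharedModelBatch.predictAll(model::Predict, instances);`, assuming the Predict signature stated above.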
The model data is stored in a separate .bin file, which needs to be in the classpath when running the model.
# Generate the Java code and binary data
java_model_files = model.to_standalone_java(export_dir=".")
# Print the content of the Java file
print(java_model_files["YdfModel.java"].decode())
Save the contents of java_model_files["YdfModelData.bin"] in the classpath.
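A missing .bin file only surfaces at runtime, when the model is constructed. A minimal sketch for checking up front that a resource is visible on the classpath; ClasspathCheck is a hypothetical helper, and the resource path the generated class actually uses is an assumption (if it loads the file relative to its own package, the path would look like ydf_model/YdfModelData.bin).

```java
import java.io.IOException;
import java.io.InputStream;

// Hypothetical helper: reports whether a resource can be opened from the classpath.
final class ClasspathCheck {
  private ClasspathCheck() {}

  static boolean isOnClasspath(String resourcePath) {
    ClassLoader loader = Thread.currentThread().getContextClassLoader();
    if (loader == null) {
      loader = ClasspathCheck.class.getClassLoader(); // Fall back to this class's loader.
    }
    // getResourceAsStream returns null when the resource is not found;
    // try-with-resources skips close() for a null resource.
    try (InputStream in = loader.getResourceAsStream(resourcePath)) {
      return in != null;
    } catch (IOException e) {
      return false; // Only close() can throw here.
    }
  }
}
```

For example, `ClasspathCheck.isOnClasspath("ydf_model/YdfModelData.bin")` could be called at startup to fail early with a clear message (the path is the assumption noted above).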
import ydf_model.YdfModel;
import ydf_model.YdfModel.Instance;
import ydf_model.YdfModel.Label;
import ydf_model.YdfModel.FeatureWorkclass;
import ydf_model.YdfModel.FeatureEducation;
import ydf_model.YdfModel.FeatureMaritalStatus;
import ydf_model.YdfModel.FeatureOccupation;
import ydf_model.YdfModel.FeatureRelationship;
import ydf_model.YdfModel.FeatureRace;
import ydf_model.YdfModel.FeatureSex;
import ydf_model.YdfModel.FeatureNativeCountry;

public class Predictor {
  public static void main(String[] args) {
    try {
      // Loads the model data from YdfModelData.bin in the classpath.
      YdfModel model = new YdfModel();

      // Build one example; categorical features use the generated enums.
      Instance instance = new Instance();
      instance.age = 39;
      instance.workclass = FeatureWorkclass.STATE_GOV;
      instance.fnlwgt = 77516;
      instance.education = FeatureEducation.BACHELORS;
      instance.education_num = 13;
      instance.marital_status = FeatureMaritalStatus.NEVER_MARRIED;
      instance.occupation = FeatureOccupation.ADM_CLERICAL;
      instance.relationship = FeatureRelationship.NOT_IN_FAMILY;
      instance.race = FeatureRace.WHITE;
      instance.sex = FeatureSex.MALE;
      instance.capital_gain = 2174;
      instance.capital_loss = 0;
      instance.hours_per_week = 40;
      instance.native_country = FeatureNativeCountry.UNITED_STATES;

      // With the default classification output, Predict returns a Label.
      Label prediction = model.Predict(instance);
      if (prediction == Label.LT_50K) {
        System.out.println("Prediction: <=50K");
      } else if (prediction == Label.GT_50K) {
        System.out.println("Prediction: >50K");
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
}
By default, Predict returns a class for classification models. With the
classification_output argument, the method can instead return a probability (or
probabilities, for multi-class models) or scores (e.g., logits). For example:
- model.to_standalone_java(classification_output='PROBABILITY'): Returns a probability (float) or probabilities (an array of floats).
- model.to_standalone_java(classification_output='SCORE'): Returns scores.
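To see how the SCORE and PROBABILITY outputs relate, here is a minimal sketch of the usual conversion for gradient boosted trees, assuming log-likelihood losses (the model above reports BINOMIAL_LOG_LIKELIHOOD): a binary score is a log-odds that the sigmoid maps to the positive-class probability, and multi-class scores go through softmax. ScoreToProbability is a hypothetical class, not part of the generated code, and this tutorial does not show the exact shape of the generated SCORE output.

```java
// Hypothetical sketch: converts gradient boosted trees scores into probabilities,
// assuming binomial (two classes) or multinomial (more classes) log-likelihood losses.
final class ScoreToProbability {
  private ScoreToProbability() {}

  // Binary case: the score is a log-odds; the sigmoid maps it to the
  // probability of the positive class.
  static float sigmoid(float score) {
    return (float) (1.0 / (1.0 + Math.exp(-score)));
  }

  // Multi-class case: one score per class, normalized with softmax.
  static float[] softmax(float[] scores) {
    // Track the maximum; it is subtracted below for numerical stability.
    float max = Float.NEGATIVE_INFINITY;
    for (float s : scores) {
      max = Math.max(max, s);
    }
    double[] exps = new double[scores.length];
    double sum = 0.0;
    for (int i = 0; i < scores.length; i++) {
      exps[i] = Math.exp(scores[i] - max);
      sum += exps[i];
    }
    float[] probabilities = new float[scores.length];
    for (int i = 0; i < scores.length; i++) {
      probabilities[i] = (float) (exps[i] / sum);
    }
    return probabilities;
  }
}
```

A score of 0 corresponds to a probability of 0.5, so thresholding the probability at 0.5 is equivalent to checking the sign of a binary score.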
Categorical feature values are created from the corresponding enum class e.g. FeatureRelationship.NOT_IN_FAMILY.
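When feature values come from raw data (e.g., CSV strings), they need to be mapped onto these enum constants. A minimal sketch, assuming the constant names follow the pattern visible in this tutorial ("State-gov" becomes STATE_GOV, "Not-in-family" becomes NOT_IN_FAMILY, "United-States" becomes UNITED_STATES). CategoricalNames and DemoRelationship are hypothetical; values that start with a digit (such as "10th") are not valid Java identifiers as-is, so the generated names for them may differ.

```java
import java.util.Locale;

// Hypothetical helper: maps a raw dataset value such as "Not-in-family" to a
// constant name such as NOT_IN_FAMILY, then looks it up on an enum class.
// The naming rule is inferred from this tutorial's examples, not from YDF.
final class CategoricalNames {
  private CategoricalNames() {}

  static String toConstantName(String rawValue) {
    // Collapse each run of non-alphanumeric characters into one underscore.
    String name = rawValue.trim().replaceAll("[^A-Za-z0-9]+", "_");
    // Remove underscores left at the edges, then upper-case.
    return name.replaceAll("^_+|_+$", "").toUpperCase(Locale.ROOT);
  }

  // Throws IllegalArgumentException if no constant has the derived name.
  static <E extends Enum<E>> E parse(Class<E> enumType, String rawValue) {
    return Enum.valueOf(enumType, toConstantName(rawValue));
  }
}

// Stand-in for a generated enum such as FeatureRelationship (values abridged).
enum DemoRelationship { HUSBAND, WIFE, NOT_IN_FAMILY, OWN_CHILD }
```

With the generated code, the equivalent call would be along the lines of `CategoricalNames.parse(FeatureRelationship.class, rawValue)`, subject to the naming assumption above.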
If you look at the content of the Predict function, you will see a for-loop
over the trees and a while-loop over the nodes. This is called the "routing"
algorithm, and it is a simple and generally efficient way to generate
predictions with a decision forest.
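The routing algorithm can be sketched as follows. This is not the generated code: the flattened node layout (feature, threshold, posChild, negChild, value arrays) is a hypothetical choice for illustration, and only numerical "feature >= threshold" conditions are modeled. Real YDF trees also contain categorical [BITMAP] conditions and missing-value routing (the miss: field in the printed tree).

```java
// Minimal sketch of the routing algorithm on a hypothetical flattened layout.
// For node i: feature[i] < 0 marks a leaf with output value[i]; otherwise the
// node tests x[feature[i]] >= threshold[i] and continues to posChild[i] when
// true or negChild[i] when false.
final class RoutingSketch {
  private RoutingSketch() {}

  static final class Tree {
    final int[] feature;
    final float[] threshold;
    final int[] posChild;
    final int[] negChild;
    final float[] value;

    Tree(int[] feature, float[] threshold, int[] posChild, int[] negChild, float[] value) {
      this.feature = feature;
      this.threshold = threshold;
      this.posChild = posChild;
      this.negChild = negChild;
      this.value = value;
    }
  }

  static float predictScore(Tree[] trees, float initialPrediction, float[] x) {
    float score = initialPrediction;
    for (Tree tree : trees) {            // For-loop over the trees.
      int node = 0;                      // Start at the root.
      while (tree.feature[node] >= 0) {  // While-loop over the nodes until a leaf.
        boolean pos = x[tree.feature[node]] >= tree.threshold[node];
        node = pos ? tree.posChild[node] : tree.negChild[node];
      }
      score += tree.value[node];         // Gradient boosting adds up the leaf outputs.
    }
    return score;
  }
}
```

For example, a one-condition tree testing feature 0 >= 12.5 (similar to "education_num">=12.5 above) sends x = {13} to its positive leaf. Flat arrays keep nodes contiguous in memory, which is part of why this simple loop is generally efficient.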
Build Rule Integration
Instead of manually saving the result of model.to_standalone_java() to a file,
you can use the java_ydf_standalone_model Blaze/Bazel rule. The steps are:
1. Save the model with model.save(...) in a new directory in your source code
(e.g., in Google3).
model.save("my_project/ydf_model_data")
2. Create a BUILD file with a filegroup in the model directory:
File: my_project/ydf_model_data/BUILD
filegroup(name = "ydf_model_data", srcs = glob(["**"]))
3. In your library's BUILD file, create a java_ydf_standalone_model build rule.
File: my_project/BUILD
load("//third_party/yggdrasil_decision_forests/serving/embed:embed.bzl", "java_ydf_standalone_model")

java_ydf_standalone_model(
    name = "ydf_model",  # Rule name, .java filename, generated .bin filename.
    package_name = "ydf_model",  # Name of the Java package where this rule is defined.
    data = "//my_project/ydf_model_data",
    # Compilation options here.
    classification_output = "PROBABILITY",
    constraints = ["android"],  # Add this if building for Android.
)
4. In your java_library, add ":ydf_model" as a dependency.
File: my_project/BUILD
java_library(
    name = "main",
    srcs = ["MyClass.java"],
    deps = [":ydf_model"],
)
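5. In your Java code, import and call the model as shown in the example above.
Note that with classification_output = "PROBABILITY", Predict returns a probability instead of a Label.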