In C++ [Standalone]¶
Setup¶
pip install ydf -U
What is C++ Standalone?¶
Once trained, YDF models can be integrated into your C++ software using one of two solutions:
YDF Lib: Copy your model data into your binary (or copy it in a directory accessible by your binary) and load it using the YDF library. This approach lets you change the model without recompiling your library, as detailed in the In C++ tutorial.
YDF Standalone (this tutorial): Compile your model into a dependency-free .h file that you include directly in your code. This solution generates significantly smaller code (up to 700x reduction observed), has no YDF dependency, which improves portability, and offers a simpler API.
How to use C++ Standalone?¶
YDF models can be integrated in two ways:
Direct Code Generation: Call model.to_standalone_cc() to generate the source code. This option is simple and well suited to experimentation.
Build Rule Integration: For production, save your model (e.g., in Google3) and use a cc_ydf_standalone_model Blaze/Bazel rule. This option automatically calls to_standalone_cc during compilation, which simplifies model updates and option testing.
Both methods are demonstrated in this tutorial.
Import libraries¶
import pandas as pd
import ydf
Training a small model¶
First, we train a small YDF model on the Adult dataset.
# Download a classification dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
model = ydf.GradientBoostedTreesLearner(label="income", num_trees=2).train(
train_ds
)
# Note: Only train 2 trees to make the generated code smaller.
model.describe()
Train model on 22792 examples
Model trained in 0:00:00.025254
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Trained with Feature Selection : No
Model size : 40 kB
Number of records: 22792
Number of columns: 15
Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)
Columns:
CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%) dtype:DTYPE_BYTES
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%) dtype:DTYPE_BYTES
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%) dtype:DTYPE_BYTES
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%) dtype:DTYPE_BYTES
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%) dtype:DTYPE_BYTES
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%) dtype:DTYPE_BYTES
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%) dtype:DTYPE_BYTES
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%) dtype:DTYPE_BYTES
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%) dtype:DTYPE_BYTES
NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661 dtype:DTYPE_INT64
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423 dtype:DTYPE_INT64
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427 dtype:DTYPE_INT64
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48 dtype:DTYPE_INT64
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01 dtype:DTYPE_INT64
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249 dtype:DTYPE_INT64
Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 1.00595
Accuracy: 0.736609 CI95[W][0 1]
ErrorRate: : 0.263391
Confusion Table:
truth\prediction
<=50K >50K
<=50K 1664 0
>50K 595 0
Total: 2259
Variable importances measure the importance of an input feature for a model.
INV_MEAN_MIN_DEPTH:
1. "relationship" 1.000000 ################
2. "capital_gain" 0.393125 ####
3. "education_num" 0.271300 #
4. "age" 0.213460
5. "education" 0.200074
6. "occupation" 0.189986
7. "capital_loss" 0.186946
8. "fnlwgt" 0.173264
9. "hours_per_week" 0.172198
10. "workclass" 0.170141
11. "native_country" 0.170141
NUM_AS_ROOT:
1. "relationship" 2.000000
NUM_NODES:
1. "capital_gain" 10.000000 ################
2. "age" 9.000000 ##############
3. "occupation" 8.000000 ############
4. "capital_loss" 8.000000 ############
5. "education" 5.000000 #######
6. "fnlwgt" 4.000000 #####
7. "education_num" 4.000000 #####
8. "hours_per_week" 3.000000 ###
9. "relationship" 2.000000 #
10. "workclass" 1.000000
11. "native_country" 1.000000
SUM_SCORE:
1. "relationship" 1358.045196 ################
2. "capital_gain" 592.782820 ######
3. "education_num" 581.188269 ######
4. "occupation" 153.072061 #
5. "capital_loss" 80.772546
6. "education" 80.057732
7. "age" 56.385846
8. "hours_per_week" 8.637064
9. "fnlwgt" 5.569371
10. "native_country" 3.053526
11. "workclass" 0.221624
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Below is the first tree of the model. The model contains 2 trees, which jointly make the prediction. Other trees can be printed with `model.print_tree(tree_idx)` or plotted with `model.plot_tree(tree_idx)`.
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
├─(pos)─ "education_num">=12.5 [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
| ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
| | ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
| | | ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
| | | | ├─(pos)─ pred:0.309737
| | | | └─(neg)─ pred:0.418684
| | | └─(neg)─ pred:0.309737
| | └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
| | ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
| | | ├─(pos)─ pred:0.349312
| | | └─(neg)─ pred:0.417359
| | └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
| | ├─(pos)─ pred:0.253437
| | └─(neg)─ pred:0.11557
| └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
| ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
| | ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
| | | ├─(pos)─ pred:0.397934
| | | └─(neg)─ pred:0.205614
| | └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
| | ├─(pos)─ pred:0.419984
| | └─(neg)─ pred:0.419984
| └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
| ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
| | ├─(pos)─ pred:0.132992
| | └─(neg)─ pred:0.00826457
| └─(neg)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Bachelors, Masters, Assoc-voc, Assoc-acdm, Prof-school, Doctorate} [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
| ├─(pos)─ pred:0.00969668
| └─(neg)─ pred:-0.0813718
└─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
| ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
| | ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
| | | ├─(pos)─ pred:0.392422
| | | └─(neg)─ pred:0.419984
| | └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
| | ├─(pos)─ pred:0.419984
| | └─(neg)─ pred:-0.0210046
| └─(neg)─ pred:0.0892425
└─(neg)─ "education" is in [BITMAP] {<OOD>, Bachelors, Masters, Prof-school, Doctorate} [s:0.00229611 n:11121 np:2199 miss:1] ; pred:-0.10399
├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
| ├─(pos)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Assoc-voc, 11th, Assoc-acdm, 10th, 7th-8th, Prof-school, 9th, ...[5 left]} [s:0.0110157 n:1263 np:125 miss:1] ; pred:-0.0103552
| | ├─(pos)─ pred:0.16421
| | └─(neg)─ pred:-0.0295298
| └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
| ├─(pos)─ pred:0.19949
| └─(neg)─ pred:-0.106976
└─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
| ├─(pos)─ pred:-0.0328167
| └─(neg)─ pred:0.292776
└─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
├─(pos)─ pred:-0.0927111
└─(neg)─ pred:-0.123347
Direct Code Generation¶
Let's generate the model .h file. It contains the following symbols:
Instance struct: An input example. Each input feature is an attribute (e.g., age, workclass).
Predict function: A thread-safe function that consumes an Instance and returns a label class (for classification).
Label: The label values. In this case, this is a binary classification model with two labels, Label::kLt50K and Label::kGt50K.
Categorical enums: An enum class for each categorical input feature, e.g., FeatureWorkclass, FeatureEducation.
print(model.to_standalone_cc())
#ifndef YDF_MODEL_YDF_MODEL_H_
#define YDF_MODEL_YDF_MODEL_H_
#include <stdint.h>
#include <cstring>
#include <array>
#include <algorithm>
#include <bitset>
#include <cassert>
namespace ydf_model {
enum class Label : uint32_t {
kLt50K = 0,
kGt50K = 1,
};
enum class FeatureWorkclass : uint32_t {
kOutOfVocabulary = 0,
kPrivate = 1,
kSelfEmpNotInc = 2,
kLocalGov = 3,
kStateGov = 4,
kSelfEmpInc = 5,
kFederalGov = 6,
kWithoutPay = 7,
};
enum class FeatureEducation : uint32_t {
kOutOfVocabulary = 0,
kHsGrad = 1,
kSomeCollege = 2,
kBachelors = 3,
kMasters = 4,
kAssocVoc = 5,
k11th = 6,
kAssocAcdm = 7,
k10th = 8,
k7th8th = 9,
kProfSchool = 10,
k9th = 11,
k12th = 12,
kDoctorate = 13,
k5th6th = 14,
k1st4th = 15,
kPreschool = 16,
};
enum class FeatureMaritalStatus : uint32_t {
kOutOfVocabulary = 0,
kMarriedCivSpouse = 1,
kNeverMarried = 2,
kDivorced = 3,
kWidowed = 4,
kSeparated = 5,
kMarriedSpouseAbsent = 6,
kMarriedAfSpouse = 7,
};
enum class FeatureOccupation : uint32_t {
kOutOfVocabulary = 0,
kProfSpecialty = 1,
kExecManagerial = 2,
kCraftRepair = 3,
kAdmClerical = 4,
kSales = 5,
kOtherService = 6,
kMachineOpInspct = 7,
kTransportMoving = 8,
kHandlersCleaners = 9,
kFarmingFishing = 10,
kTechSupport = 11,
kProtectiveServ = 12,
kPrivHouseServ = 13,
};
enum class FeatureRelationship : uint32_t {
kOutOfVocabulary = 0,
kHusband = 1,
kNotInFamily = 2,
kOwnChild = 3,
kUnmarried = 4,
kWife = 5,
kOtherRelative = 6,
};
enum class FeatureRace : uint32_t {
kOutOfVocabulary = 0,
kWhite = 1,
kBlack = 2,
kAsianPacIslander = 3,
kAmerIndianEskimo = 4,
kOther = 5,
};
enum class FeatureSex : uint32_t {
kOutOfVocabulary = 0,
kMale = 1,
kFemale = 2,
};
enum class FeatureNativeCountry : uint32_t {
kOutOfVocabulary = 0,
kUnitedStates = 1,
kMexico = 2,
kPhilippines = 3,
kGermany = 4,
kCanada = 5,
kPuertoRico = 6,
kIndia = 7,
kElSalvador = 8,
kCuba = 9,
kEngland = 10,
kJamaica = 11,
kDominicanRepublic = 12,
kSouth = 13,
kChina = 14,
kItaly = 15,
kColumbia = 16,
kGuatemala = 17,
kJapan = 18,
kVietnam = 19,
kTaiwan = 20,
kIran = 21,
kPoland = 22,
kHaiti = 23,
kNicaragua = 24,
kGreece = 25,
kPortugal = 26,
kEcuador = 27,
kFrance = 28,
kPeru = 29,
kThailand = 30,
kCambodia = 31,
kIreland = 32,
kLaos = 33,
kYugoslavia = 34,
kTrinadadTobago = 35,
kHonduras = 36,
kHong = 37,
kHungary = 38,
kScotland = 39,
kOutlyingUsGuamUsviEtc = 40,
};
constexpr const int kNumFeatures = 14;
constexpr const int kNumTrees = 2;
struct Instance {
typedef int32_t Numerical;
Numerical age;
FeatureWorkclass workclass;
Numerical fnlwgt;
FeatureEducation education;
Numerical education_num;
FeatureMaritalStatus marital_status;
FeatureOccupation occupation;
FeatureRelationship relationship;
FeatureRace race;
FeatureSex sex;
Numerical capital_gain;
Numerical capital_loss;
Numerical hours_per_week;
FeatureNativeCountry native_country;
};
struct __attribute__((packed)) Node {
uint8_t pos = 0;
union {
struct {
uint8_t feat;
union {
int32_t thr;
uint16_t cat;
};
} cond;
struct {
float val;
} leaf;
};
};
static const Node nodes[] = {
{.pos=25,.cond={.feat=7,.cat=51}},
{.pos=15,.cond={.feat=10,.thr=7074}},
{.pos=7,.cond={.feat=3,.cat=0}},
{.pos=3,.cond={.feat=11,.thr=2219}},
{.pos=1,.cond={.feat=12,.thr=41}},
{.leaf={.val=-0.123347}},
{.leaf={.val=-0.0927111}},
{.pos=1,.cond={.feat=2,.thr=125451}},
{.leaf={.val=0.292776}},
{.leaf={.val=-0.0328167}},
{.pos=3,.cond={.feat=0,.thr=32}},
{.pos=1,.cond={.feat=11,.thr=1977}},
{.leaf={.val=-0.106976}},
{.leaf={.val=0.19949}},
{.pos=1,.cond={.feat=3,.cat=17}},
{.leaf={.val=-0.0295298}},
{.leaf={.val=0.16421}},
{.pos=1,.cond={.feat=0,.thr=22}},
{.leaf={.val=0.0892425}},
{.pos=3,.cond={.feat=10,.thr=7566}},
{.pos=1,.cond={.feat=3,.cat=34}},
{.leaf={.val=-0.0210046}},
{.leaf={.val=0.419984}},
{.pos=1,.cond={.feat=10,.thr=30962}},
{.leaf={.val=0.419984}},
{.leaf={.val=0.392422}},
{.pos=15,.cond={.feat=4,.thr=13}},
{.pos=7,.cond={.feat=10,.thr=5096}},
{.pos=3,.cond={.feat=6,.cat=75}},
{.pos=1,.cond={.feat=3,.cat=58}},
{.leaf={.val=-0.0813718}},
{.leaf={.val=0.00969668}},
{.pos=1,.cond={.feat=0,.thr=34}},
{.leaf={.val=0.00826457}},
{.leaf={.val=0.132992}},
{.pos=3,.cond={.feat=0,.thr=61}},
{.pos=1,.cond={.feat=2,.thr=36213}},
{.leaf={.val=0.419984}},
{.leaf={.val=0.419984}},
{.pos=1,.cond={.feat=6,.cat=89}},
{.leaf={.val=0.205614}},
{.leaf={.val=0.397934}},
{.pos=7,.cond={.feat=10,.thr=5096}},
{.pos=3,.cond={.feat=11,.thr=1783}},
{.pos=1,.cond={.feat=6,.cat=103}},
{.leaf={.val=0.11557}},
{.leaf={.val=0.253437}},
{.pos=1,.cond={.feat=11,.thr=1990}},
{.leaf={.val=0.417359}},
{.leaf={.val=0.349312}},
{.pos=1,.cond={.feat=6,.cat=117}},
{.leaf={.val=0.309737}},
{.pos=1,.cond={.feat=0,.thr=80}},
{.leaf={.val=0.418684}},
{.leaf={.val=0.309737}},
{.pos=25,.cond={.feat=7,.cat=148}},
{.pos=15,.cond={.feat=10,.thr=7074}},
{.pos=7,.cond={.feat=3,.cat=131}},
{.pos=3,.cond={.feat=11,.thr=2219}},
{.pos=1,.cond={.feat=12,.thr=41}},
{.leaf={.val=-0.11917}},
{.leaf={.val=-0.0879641}},
{.pos=1,.cond={.feat=2,.thr=125451}},
{.leaf={.val=0.227849}},
{.leaf={.val=-0.0300817}},
{.pos=3,.cond={.feat=0,.thr=32}},
{.pos=1,.cond={.feat=12,.thr=45}},
{.leaf={.val=-0.114477}},
{.leaf={.val=-0.0633502}},
{.pos=1,.cond={.feat=4,.thr=15}},
{.leaf={.val=-0.0270186}},
{.leaf={.val=0.13565}},
{.pos=1,.cond={.feat=0,.thr=22}},
{.leaf={.val=0.0765646}},
{.pos=3,.cond={.feat=10,.thr=7566}},
{.pos=1,.cond={.feat=4,.thr=11}},
{.leaf={.val=-0.0191264}},
{.leaf={.val=0.310248}},
{.pos=1,.cond={.feat=10,.thr=30962}},
{.leaf={.val=0.310248}},
{.leaf={.val=0.293003}},
{.pos=15,.cond={.feat=4,.thr=13}},
{.pos=7,.cond={.feat=10,.thr=5096}},
{.pos=3,.cond={.feat=6,.cat=155}},
{.pos=1,.cond={.feat=11,.thr=1794}},
{.leaf={.val=-0.0159635}},
{.leaf={.val=0.192961}},
{.pos=1,.cond={.feat=11,.thr=1783}},
{.leaf={.val=0.0773945}},
{.leaf={.val=0.291468}},
{.pos=3,.cond={.feat=0,.thr=61}},
{.pos=1,.cond={.feat=2,.thr=45794}},
{.leaf={.val=0.310248}},
{.leaf={.val=0.310248}},
{.pos=1,.cond={.feat=6,.cat=169}},
{.leaf={.val=0.166497}},
{.leaf={.val=0.296477}},
{.pos=7,.cond={.feat=10,.thr=5096}},
{.pos=3,.cond={.feat=6,.cat=224}},
{.pos=1,.cond={.feat=13,.cat=183}},
{.leaf={.val=0.0771326}},
{.leaf={.val=0.38506}},
{.pos=1,.cond={.feat=11,.thr=1783}},
{.leaf={.val=0.198242}},
{.leaf={.val=0.303271}},
{.pos=3,.cond={.feat=0,.thr=63}},
{.pos=1,.cond={.feat=6,.cat=238}},
{.leaf={.val=0.310521}},
{.leaf={.val=0.318377}},
{.pos=1,.cond={.feat=1,.cat=252}},
{.leaf={.val=0.222857}},
{.leaf={.val=0.300409}},
};
static const uint8_t condition_types[] = {0,1,0,1,0,1,1,1,1,1,0,0,0,1};
static const uint8_t root_deltas[] = {55,57};
static const std::bitset<260> categorical_bank {"00111011010100000000000110000011011000000000010000000111000001000000000010000000011101101100110000011011001000110001001000001100101101111111111011000001001100000111011011001100000110110000100100101111110100011000000100001110001111111111110011100010010000011001"};
inline Label Predict(const Instance& instance) {
float accumulator {-1.1631};
const Node* root = nodes;
const Node* node;
const char* raw_instance = (const char*)(&instance);
uint8_t eval;
for (uint8_t tree_idx = 0; tree_idx != kNumTrees; tree_idx++) {
node = root;
while(node->pos) {
if (condition_types[node->cond.feat] == 0) {
int32_t numerical_feature;
std::memcpy(&numerical_feature, raw_instance + node->cond.feat * sizeof(int32_t), sizeof(int32_t));
eval = numerical_feature >= node->cond.thr;
} else if (condition_types[node->cond.feat] == 1) {
uint32_t categorical_feature;
std::memcpy(&categorical_feature, raw_instance + node->cond.feat * sizeof(uint32_t), sizeof(uint32_t));
eval = categorical_bank[categorical_feature + node->cond.cat];
} else {
assert(false);
}
node += (node->pos & -eval) + 1;
}
accumulator += node->leaf.val;
root += root_deltas[tree_idx];
}
return static_cast<Label>(accumulator >= 0);
}
} // namespace ydf_model
#endif
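The categorical conditions in the generated Predict function above index one shared bitset: each condition stores an offset (cond.cat), and the integer value of the feature selects a bit after that offset. The following is a minimal sketch of that lookup with a hand-built bank. The names Color, MakeBank, and IsInMask are illustrative assumptions and are not part of the generated header.

```cpp
#include <bitset>
#include <cassert>
#include <cstdint>

// Hypothetical categorical feature (not part of the generated header).
enum class Color : uint32_t { kOutOfVocabulary = 0, kRed = 1, kBlue = 2 };

// Builds a small bank that stores two masks back to back:
//   offset 0: {kRed}                    -> bits 0..2 = 0,1,0
//   offset 3: {kOutOfVocabulary, kBlue} -> bits 3..5 = 1,0,1
inline std::bitset<6> MakeBank() {
  std::bitset<6> bank;
  bank.set(0 + static_cast<uint32_t>(Color::kRed));
  bank.set(3 + static_cast<uint32_t>(Color::kOutOfVocabulary));
  bank.set(3 + static_cast<uint32_t>(Color::kBlue));
  return bank;
}

// Returns true if `value` belongs to the mask that starts at `offset`.
inline bool IsInMask(const std::bitset<6>& bank, uint16_t offset, Color value) {
  return bank[static_cast<uint32_t>(value) + offset];
}
```

Storing every mask in one bitset keeps each node small, since a node only needs the offset of its mask.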
In your C++ code, call the model as:
#include "ydf_model.h"
void f() {
  using namespace ydf_model;
  const Label prediction = Predict(Instance{
      .age = 39,
      .workclass = FeatureWorkclass::kStateGov,
      .fnlwgt = 77516,
      .education = FeatureEducation::kBachelors,
      .education_num = 13,
      .marital_status = FeatureMaritalStatus::kNeverMarried,
      .occupation = FeatureOccupation::kAdmClerical,
      .relationship = FeatureRelationship::kNotInFamily,
      .race = FeatureRace::kWhite,
      .sex = FeatureSex::kMale,
      .capital_gain = 2174,
      .capital_loss = 0,
      .hours_per_week = 40,
      .native_country = FeatureNativeCountry::kUnitedStates,
  });
  if (prediction == Label::kLt50K) {
    // ...
  } else if (prediction == Label::kGt50K) {
    // ...
  }
}
By default, Predict returns a class for a classification model. Alternatively,
the function can return a probability (or probabilities for a multi-class
model) or scores (e.g., logits) through the classification_output argument. For example:
model.to_standalone_cc(classification_output='PROBABILITY'): Returns a probability (float) or probabilities (std::array<float>).
model.to_standalone_cc(classification_output='SCORE'): Returns scores.
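For a binary model trained with the BINOMIAL_LOG_LIKELIHOOD loss, as reported in the evaluation earlier, a score is a logit, and a probability is conventionally obtained from it with the sigmoid function. The sketch below only illustrates that relation. ScoreToProbability and ScoreIsPositive are hypothetical helper names, and the code actually generated for each classification_output value may be organized differently.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: maps a logit (score) to the probability of the
// positive class with the sigmoid function.
inline float ScoreToProbability(float score) {
  return 1.0f / (1.0f + std::exp(-score));
}

// Thresholding the score at 0 is equivalent to thresholding the probability
// at 0.5; this mirrors the `accumulator >= 0` test in the generated Predict.
inline bool ScoreIsPositive(float score) { return score >= 0.0f; }
```

Returning a probability or a score instead of a class lets the caller choose a decision threshold other than 0.5.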
Categorical feature values are created from the corresponding enum class, e.g., FeatureRelationship::kNotInFamily. Categorical values can also be created from a string, e.g., FeatureRelationshipFromString("Not-in-family"), although this is less efficient and can lead to a larger binary. The "*FromString" symbols are generated if the model is exported with categorical_from_string=True.
Note: If a string does not match an existing categorical value, the kOutOfVocabulary value is returned.
{
    .age = 39,
    .workclass = FeatureWorkclassFromString("State-gov"),
    .fnlwgt = 77516,
    .education = FeatureEducationFromString("Bachelors"),
    ...
}
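To illustrate the out-of-vocabulary behavior described in the note, here is a minimal sketch of a string-to-enum lookup for the sex feature. SketchFeatureSexFromString is a hypothetical stand-in written for this example, not the generated FeatureSexFromString, and the exact, case-sensitive comparison is an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Same values as the FeatureSex enum in the generated header.
enum class FeatureSex : uint32_t {
  kOutOfVocabulary = 0,
  kMale = 1,
  kFemale = 2,
};

// Hypothetical lookup: any string that is not a known value maps to
// kOutOfVocabulary instead of raising an error.
inline FeatureSex SketchFeatureSexFromString(const std::string& name) {
  if (name == "Male") return FeatureSex::kMale;
  if (name == "Female") return FeatureSex::kFemale;
  return FeatureSex::kOutOfVocabulary;
}
```

Because unknown strings are accepted silently, a misspelled value is not reported and is simply treated as out of vocabulary.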
If you look at the content of the Predict function, you will see a for-loop
over the trees and a while-loop over the nodes. This is called the "routing"
algorithm, and it is a simple and generally efficient way to generate
predictions with a decision forest.
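The node-jumping step of the routing loop can be easier to follow on a tiny example. The sketch below is a simplified illustration with a hypothetical single-feature tree; the Node layout, kNodes, and Route are assumptions made for this example and are not the generated code. A node with pos equal to 0 is a leaf. Otherwise the negative child is the next node, and the positive child sits pos nodes further.

```cpp
#include <cassert>
#include <cstdint>

// Simplified node for illustration (numerical conditions only).
struct Node {
  uint8_t pos;  // 0 for a leaf; otherwise the number of nodes to skip to
                // reach the positive child (size of the negative subtree).
  int32_t thr;  // Threshold of the condition `x >= thr` (condition nodes).
  float val;    // Value of the leaf (leaf nodes).
};

// One hypothetical tree over a single feature x:
//   x >= 10 ? (x >= 20 ? 3.0 : 2.0) : 1.0
static const Node kNodes[] = {
    {1, 10, 0.0f},  // [0] x >= 10: negative child [1], positive child [2]
    {0, 0, 1.0f},   // [1] leaf
    {1, 20, 0.0f},  // [2] x >= 20: negative child [3], positive child [4]
    {0, 0, 2.0f},   // [3] leaf
    {0, 0, 3.0f},   // [4] leaf
};

inline float Route(int32_t x) {
  const Node* node = kNodes;
  while (node->pos) {
    const uint8_t eval = x >= node->thr;
    // eval == 0: -eval is 0, so the jump is 1 (negative child).
    // eval == 1: -eval is all ones, so the jump is pos + 1 (positive child).
    node += (node->pos & -eval) + 1;
  }
  return node->val;
}
```

The expression `(node->pos & -eval) + 1` selects the child without an if statement on the condition result, which avoids a data-dependent branch at each node.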
Other algorithms are available with the algorithm argument. For example, the
code generated with algorithm="IF_ELSE" is a succession of nested
if-else statements.
In the following cell, check the content of the Predict function at the
bottom.
print(model.to_standalone_cc(algorithm="IF_ELSE"))
#ifndef YDF_MODEL_YDF_MODEL_H_
#define YDF_MODEL_YDF_MODEL_H_
#include <stdint.h>
#include <cstring>
#include <array>
#include <algorithm>
#include <bitset>
#include <cassert>
namespace ydf_model {
enum class Label : uint32_t {
kLt50K = 0,
kGt50K = 1,
};
enum class FeatureWorkclass : uint32_t {
kOutOfVocabulary = 0,
kPrivate = 1,
kSelfEmpNotInc = 2,
kLocalGov = 3,
kStateGov = 4,
kSelfEmpInc = 5,
kFederalGov = 6,
kWithoutPay = 7,
};
enum class FeatureEducation : uint32_t {
kOutOfVocabulary = 0,
kHsGrad = 1,
kSomeCollege = 2,
kBachelors = 3,
kMasters = 4,
kAssocVoc = 5,
k11th = 6,
kAssocAcdm = 7,
k10th = 8,
k7th8th = 9,
kProfSchool = 10,
k9th = 11,
k12th = 12,
kDoctorate = 13,
k5th6th = 14,
k1st4th = 15,
kPreschool = 16,
};
enum class FeatureMaritalStatus : uint32_t {
kOutOfVocabulary = 0,
kMarriedCivSpouse = 1,
kNeverMarried = 2,
kDivorced = 3,
kWidowed = 4,
kSeparated = 5,
kMarriedSpouseAbsent = 6,
kMarriedAfSpouse = 7,
};
enum class FeatureOccupation : uint32_t {
kOutOfVocabulary = 0,
kProfSpecialty = 1,
kExecManagerial = 2,
kCraftRepair = 3,
kAdmClerical = 4,
kSales = 5,
kOtherService = 6,
kMachineOpInspct = 7,
kTransportMoving = 8,
kHandlersCleaners = 9,
kFarmingFishing = 10,
kTechSupport = 11,
kProtectiveServ = 12,
kPrivHouseServ = 13,
};
enum class FeatureRelationship : uint32_t {
kOutOfVocabulary = 0,
kHusband = 1,
kNotInFamily = 2,
kOwnChild = 3,
kUnmarried = 4,
kWife = 5,
kOtherRelative = 6,
};
enum class FeatureRace : uint32_t {
kOutOfVocabulary = 0,
kWhite = 1,
kBlack = 2,
kAsianPacIslander = 3,
kAmerIndianEskimo = 4,
kOther = 5,
};
enum class FeatureSex : uint32_t {
kOutOfVocabulary = 0,
kMale = 1,
kFemale = 2,
};
enum class FeatureNativeCountry : uint32_t {
kOutOfVocabulary = 0,
kUnitedStates = 1,
kMexico = 2,
kPhilippines = 3,
kGermany = 4,
kCanada = 5,
kPuertoRico = 6,
kIndia = 7,
kElSalvador = 8,
kCuba = 9,
kEngland = 10,
kJamaica = 11,
kDominicanRepublic = 12,
kSouth = 13,
kChina = 14,
kItaly = 15,
kColumbia = 16,
kGuatemala = 17,
kJapan = 18,
kVietnam = 19,
kTaiwan = 20,
kIran = 21,
kPoland = 22,
kHaiti = 23,
kNicaragua = 24,
kGreece = 25,
kPortugal = 26,
kEcuador = 27,
kFrance = 28,
kPeru = 29,
kThailand = 30,
kCambodia = 31,
kIreland = 32,
kLaos = 33,
kYugoslavia = 34,
kTrinadadTobago = 35,
kHonduras = 36,
kHong = 37,
kHungary = 38,
kScotland = 39,
kOutlyingUsGuamUsviEtc = 40,
};
constexpr const int kNumFeatures = 14;
constexpr const int kNumTrees = 2;
struct Instance {
typedef int32_t Numerical;
Numerical age;
FeatureWorkclass workclass;
Numerical fnlwgt;
FeatureEducation education;
Numerical education_num;
FeatureMaritalStatus marital_status;
FeatureOccupation occupation;
FeatureRelationship relationship;
FeatureRace race;
FeatureSex sex;
Numerical capital_gain;
Numerical capital_loss;
Numerical hours_per_week;
FeatureNativeCountry native_country;
};
inline Label Predict(const Instance& instance) {
float accumulator {-1.1631};
// Tree #0
if (instance.relationship == FeatureRelationship::kOutOfVocabulary ||
instance.relationship == FeatureRelationship::kHusband ||
instance.relationship == FeatureRelationship::kWife) {
if (instance.education_num >= 12.5) {
if (instance.capital_gain >= 5095.5) {
if (std::array<FeatureOccupation,12> mask = { FeatureOccupation::kOutOfVocabulary, FeatureOccupation::kProfSpecialty, FeatureOccupation::kExecManagerial, FeatureOccupation::kCraftRepair, FeatureOccupation::kAdmClerical, FeatureOccupation::kSales, FeatureOccupation::kOtherService, FeatureOccupation::kMachineOpInspct, FeatureOccupation::kTransportMoving, FeatureOccupation::kHandlersCleaners, FeatureOccupation::kTechSupport, FeatureOccupation::kProtectiveServ};
std::binary_search(mask.begin(), mask.end(), instance.occupation)) {
if (instance.age >= 79.5) {
accumulator += 0.309737;
} else {
accumulator += 0.418684;
}
} else {
accumulator += 0.309737;
}
} else {
if (instance.capital_loss >= 1782.5) {
if (instance.capital_loss >= 1989.5) {
accumulator += 0.349312;
} else {
accumulator += 0.417359;
}
} else {
if (instance.occupation == FeatureOccupation::kProfSpecialty ||
instance.occupation == FeatureOccupation::kExecManagerial ||
instance.occupation == FeatureOccupation::kSales ||
instance.occupation == FeatureOccupation::kTechSupport ||
instance.occupation == FeatureOccupation::kProtectiveServ) {
accumulator += 0.253437;
} else {
accumulator += 0.11557;
}
}
}
} else {
if (instance.capital_gain >= 5095.5) {
if (instance.age >= 60.5) {
if (instance.occupation == FeatureOccupation::kProfSpecialty ||
instance.occupation == FeatureOccupation::kExecManagerial ||
instance.occupation == FeatureOccupation::kAdmClerical ||
instance.occupation == FeatureOccupation::kSales ||
instance.occupation == FeatureOccupation::kMachineOpInspct ||
instance.occupation == FeatureOccupation::kTransportMoving ||
instance.occupation == FeatureOccupation::kHandlersCleaners) {
accumulator += 0.397934;
} else {
accumulator += 0.205614;
}
} else {
if (instance.fnlwgt >= 36212.5) {
accumulator += 0.419984;
} else {
accumulator += 0.419984;
}
}
} else {
if (instance.occupation == FeatureOccupation::kProfSpecialty ||
instance.occupation == FeatureOccupation::kExecManagerial ||
instance.occupation == FeatureOccupation::kAdmClerical ||
instance.occupation == FeatureOccupation::kSales ||
instance.occupation == FeatureOccupation::kTechSupport ||
instance.occupation == FeatureOccupation::kProtectiveServ) {
if (instance.age >= 33.5) {
accumulator += 0.132992;
} else {
accumulator += 0.00826457;
}
} else {
if (std::array<FeatureEducation,9> mask = { FeatureEducation::kOutOfVocabulary, FeatureEducation::kHsGrad, FeatureEducation::kSomeCollege, FeatureEducation::kBachelors, FeatureEducation::kMasters, FeatureEducation::kAssocVoc, FeatureEducation::kAssocAcdm, FeatureEducation::kProfSchool, FeatureEducation::kDoctorate};
std::binary_search(mask.begin(), mask.end(), instance.education)) {
accumulator += 0.00969668;
} else {
accumulator += -0.0813718;
}
}
}
}
} else {
if (instance.capital_gain >= 7073.5) {
if (instance.age >= 21.5) {
if (instance.capital_gain >= 7565.5) {
if (instance.capital_gain >= 30961.5) {
accumulator += 0.392422;
} else {
accumulator += 0.419984;
}
} else {
if (instance.education == FeatureEducation::kBachelors ||
instance.education == FeatureEducation::kMasters ||
instance.education == FeatureEducation::kAssocVoc ||
instance.education == FeatureEducation::kProfSchool) {
accumulator += 0.419984;
} else {
accumulator += -0.0210046;
}
}
} else {
accumulator += 0.0892425;
}
} else {
if (instance.education == FeatureEducation::kOutOfVocabulary ||
instance.education == FeatureEducation::kBachelors ||
instance.education == FeatureEducation::kMasters ||
instance.education == FeatureEducation::kProfSchool ||
instance.education == FeatureEducation::kDoctorate) {
if (instance.age >= 31.5) {
if (std::array<FeatureEducation,15> mask = { FeatureEducation::kOutOfVocabulary, FeatureEducation::kHsGrad, FeatureEducation::kSomeCollege, FeatureEducation::kAssocVoc, FeatureEducation::k11th, FeatureEducation::kAssocAcdm, FeatureEducation::k10th, FeatureEducation::k7th8th, FeatureEducation::kProfSchool, FeatureEducation::k9th, FeatureEducation::k12th, FeatureEducation::kDoctorate, FeatureEducation::k5th6th, FeatureEducation::k1st4th, FeatureEducation::kPreschool};
std::binary_search(mask.begin(), mask.end(), instance.education)) {
accumulator += 0.16421;
} else {
accumulator += -0.0295298;
}
} else {
if (instance.capital_loss >= 1977) {
accumulator += 0.19949;
} else {
accumulator += -0.106976;
}
}
} else {
if (instance.capital_loss >= 2218.5) {
if (instance.fnlwgt >= 125450) {
accumulator += -0.0328167;
} else {
accumulator += 0.292776;
}
} else {
if (instance.hours_per_week >= 40.5) {
accumulator += -0.0927111;
} else {
accumulator += -0.123347;
}
}
}
}
}
// Tree #1
if (instance.relationship == FeatureRelationship::kOutOfVocabulary ||
instance.relationship == FeatureRelationship::kHusband ||
instance.relationship == FeatureRelationship::kWife) {
if (instance.education_num >= 12.5) {
if (instance.capital_gain >= 5095.5) {
if (instance.age >= 62.5) {
if (instance.workclass == FeatureWorkclass::kOutOfVocabulary ||
instance.workclass == FeatureWorkclass::kPrivate ||
instance.workclass == FeatureWorkclass::kLocalGov ||
instance.workclass == FeatureWorkclass::kStateGov ||
instance.workclass == FeatureWorkclass::kSelfEmpInc) {
accumulator += 0.300409;
} else {
accumulator += 0.222857;
}
} else {
if (instance.occupation == FeatureOccupation::kFarmingFishing ||
instance.occupation == FeatureOccupation::kProtectiveServ) {
accumulator += 0.318377;
} else {
accumulator += 0.310521;
}
}
} else {
if (instance.occupation == FeatureOccupation::kProfSpecialty ||
instance.occupation == FeatureOccupation::kExecManagerial ||
instance.occupation == FeatureOccupation::kAdmClerical ||
instance.occupation == FeatureOccupation::kSales ||
instance.occupation == FeatureOccupation::kTechSupport ||
instance.occupation == FeatureOccupation::kProtectiveServ) {
if (instance.capital_loss >= 1782.5) {
accumulator += 0.303271;
} else {
accumulator += 0.198242;
}
} else {
if (instance.native_country == FeatureNativeCountry::kGermany ||
instance.native_country == FeatureNativeCountry::kItaly ||
instance.native_country == FeatureNativeCountry::kIran ||
instance.native_country == FeatureNativeCountry::kPoland ||
instance.native_country == FeatureNativeCountry::kHaiti ||
instance.native_country == FeatureNativeCountry::kCambodia) {
accumulator += 0.38506;
} else {
accumulator += 0.0771326;
}
}
}
} else {
if (instance.capital_gain >= 5095.5) {
if (instance.age >= 60.5) {
if (instance.occupation == FeatureOccupation::kProfSpecialty ||
instance.occupation == FeatureOccupation::kExecManagerial ||
instance.occupation == FeatureOccupation::kAdmClerical ||
instance.occupation == FeatureOccupation::kSales ||
instance.occupation == FeatureOccupation::kMachineOpInspct ||
instance.occupation == FeatureOccupation::kTransportMoving ||
instance.occupation == FeatureOccupation::kHandlersCleaners) {
accumulator += 0.296477;
} else {
accumulator += 0.166497;
}
} else {
if (instance.fnlwgt >= 45793.5) {
accumulator += 0.310248;
} else {
accumulator += 0.310248;
}
}
} else {
if (instance.occupation == FeatureOccupation::kProfSpecialty ||
instance.occupation == FeatureOccupation::kExecManagerial ||
instance.occupation == FeatureOccupation::kAdmClerical ||
instance.occupation == FeatureOccupation::kSales ||
instance.occupation == FeatureOccupation::kTechSupport ||
instance.occupation == FeatureOccupation::kProtectiveServ) {
if (instance.capital_loss >= 1782.5) {
accumulator += 0.291468;
} else {
accumulator += 0.0773945;
}
} else {
if (instance.capital_loss >= 1794) {
accumulator += 0.192961;
} else {
accumulator += -0.0159635;
}
}
}
}
} else {
if (instance.capital_gain >= 7073.5) {
if (instance.age >= 21.5) {
if (instance.capital_gain >= 7565.5) {
if (instance.capital_gain >= 30961.5) {
accumulator += 0.293003;
} else {
accumulator += 0.310248;
}
} else {
if (instance.education_num >= 10.5) {
accumulator += 0.310248;
} else {
accumulator += -0.0191264;
}
}
} else {
accumulator += 0.0765646;
}
} else {
if (instance.education == FeatureEducation::kOutOfVocabulary ||
instance.education == FeatureEducation::kBachelors ||
instance.education == FeatureEducation::kMasters ||
instance.education == FeatureEducation::kProfSchool ||
instance.education == FeatureEducation::kDoctorate) {
if (instance.age >= 31.5) {
if (instance.education_num >= 14.5) {
accumulator += 0.13565;
} else {
accumulator += -0.0270186;
}
} else {
if (instance.hours_per_week >= 44.5) {
accumulator += -0.0633502;
} else {
accumulator += -0.114477;
}
}
} else {
if (instance.capital_loss >= 2218.5) {
if (instance.fnlwgt >= 125450) {
accumulator += -0.0300817;
} else {
accumulator += 0.227849;
}
} else {
if (instance.hours_per_week >= 40.5) {
accumulator += -0.0879641;
} else {
accumulator += -0.11917;
}
}
}
}
}
return static_cast<Label>(accumulator >= 0);
}
} // namespace ydf_model
#endif
The data type (dtype) of numerical features in your training dataset affects your compiled model's size. A model trained with int16 or int8 numerical features will be smaller than one trained with int32 or float values. In the next example, we cast the training dataset to a smaller data type to get a smaller model.
Note: To be effective, all the numerical features need to be cast.
# Before casting
train_ds.dtypes
age                int64
workclass         object
fnlwgt             int64
education         object
education_num      int64
marital_status    object
occupation        object
relationship      object
race              object
sex               object
capital_gain       int64
capital_loss       int64
hours_per_week     int64
native_country    object
income            object
dtype: object
casted_train_ds = train_ds.copy()
for col in casted_train_ds.columns:
    if casted_train_ds[col].dtype in ["int32", "int64"]:
        casted_train_ds[col] = casted_train_ds[col].astype("int16")
# After casting
casted_train_ds.dtypes
age                int16
workclass         object
fnlwgt             int16
education         object
education_num      int16
marital_status    object
occupation        object
relationship      object
race              object
sex               object
capital_gain       int16
capital_loss       int16
hours_per_week     int16
native_country    object
income            object
dtype: object
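A narrower integer type preserves the data only when every value of the feature fits in it. As a minimal sketch, assuming the minimum and maximum values reported in the model description earlier, the hypothetical helper FitsIn checks whether a feature's range is representable in a given integer type:

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: true if the closed range [min_value, max_value]
// is representable by the integer type T.
template <typename T>
constexpr bool FitsIn(int64_t min_value, int64_t max_value) {
  return min_value >= std::numeric_limits<T>::min() &&
         max_value <= std::numeric_limits<T>::max();
}

// Ranges taken from the dataset statistics printed by model.describe().
static_assert(FitsIn<int16_t>(17, 90), "age fits in int16");
static_assert(FitsIn<int16_t>(0, 4356), "capital_loss fits in int16");
static_assert(!FitsIn<int16_t>(0, 99999), "capital_gain exceeds int16");
static_assert(FitsIn<int32_t>(0, 99999), "capital_gain fits in int32");
```

With these statistics, capital_gain (and fnlwgt, whose maximum is about 1.48 million) would not keep its values under a cast to int16, so such columns may need clipping, rescaling, or a wider type before the cast.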
Build Rule Integration¶
Instead of manually saving the result of model.to_standalone_cc() to a file,
you can use the cc_ydf_standalone_model Blaze/Bazel rule. The steps are:
1. Save the model with model.save(...) in a new directory in your source code (e.g., in Google3).
model.save("my_project/ydf_model_data")
2. Create a BUILD file with a filegroup in the model directory:
File: my_project/ydf_model_data/BUILD
filegroup(name = "ydf_model_data", srcs = glob(["**"]))
3. In your library's BUILD, create a cc_ydf_standalone_model build rule.
File: my_project/BUILD
load("//third_party/yggdrasil_decision_forests/serving/embed:embed.bzl", "cc_ydf_standalone_model")
cc_ydf_standalone_model(
    name = "ydf_model",  # Rule name, .h filename, and namespace in the .h file.
    data = "//my_project/ydf_model_data",
    # Compilation options here.
    classification_output = "PROBABILITY",
)
4. In your cc_binary or cc_library, add ":ydf_model" as a dependency.
File: my_project/BUILD
cc_binary(
    name = "main",
    srcs = ["main.cc"],
    deps = [":ydf_model"],
)
5. In your C++ code, include and call the model:
#include "my_project/ydf_model.h"
using namespace ydf_model;
const auto prediction = Predict(Instance{.f1 = 5, .f2 = F2::kRed});