Classification
Setup
```shell
pip install ydf -U
```
What is Classification?
Classification is the task of predicting which category an item belongs to. The set of possible categories (or "classes") is finite. For example, predicting whether an email is SPAM or NOT_SPAM, or identifying a handwritten digit from 0 to 9 are both classification tasks.
The output of a classification model is typically a probability distribution over the possible classes. The final predicted class is the one with the highest probability.
Multi-class classification involves three or more classes (e.g., predicting RED, BLUE, or GREEN).
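Picking the final class from a probability distribution means taking the class with the highest probability. The following is a minimal sketch in plain Python. The RED/BLUE/GREEN probabilities are made-up values, not YDF output, and the helper `predicted_class` is hypothetical.

```python
# Minimal sketch: the predicted class is the one with the highest probability.
# The class names and probabilities are made-up illustrative values.


def predicted_class(probabilities: dict) -> str:
    """Returns the class with the highest probability."""
    total = sum(probabilities.values())
    # A probability distribution over the classes should sum to 1.
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"probabilities sum to {total}, expected 1.0")
    return max(probabilities, key=probabilities.get)


distribution = {"RED": 0.2, "BLUE": 0.7, "GREEN": 0.1}
print(predicted_class(distribution))  # BLUE
```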
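Binary classification is a special case with only two classes. For simplicity, YDF binary classification models return the probability of just one of the two classes (the "positive" class). Because the two probabilities sum to one, the probability of the other class is one minus this value.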
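The single binary probability can be expanded back into a full two-class distribution. The following is a minimal sketch, assuming the classes are listed as [negative, positive], which mirrors the order reported by model.label_classes(). The class names, the value 0.83, and the helper `binary_distribution` are hypothetical.

```python
# Minimal sketch: expand the single probability returned for a binary model
# into a full two-class distribution. The class names and the value 0.83
# are hypothetical.


def binary_distribution(p_positive: float, classes: list) -> dict:
    """Returns {negative_class: 1 - p, positive_class: p}.

    `classes` is assumed to be ordered as [negative, positive].
    """
    negative, positive = classes
    return {negative: 1.0 - p_positive, positive: p_positive}


dist = binary_distribution(0.83, ["NOT_SPAM", "SPAM"])
print(dist)  # NOT_SPAM is approximately 0.17
print(max(dist, key=dist.get))  # SPAM
```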
YDF can handle labels that are strings, integers, or boolean values.
Training a classification model
In YDF, you specify the model's objective using the task argument in the learner. For classification, the value is ydf.Task.CLASSIFICATION. Since this is the default setting, you often don't need to specify it.
Let's train a model on a real dataset.
```python
# Load necessary libraries
import ydf  # Yggdrasil Decision Forests.
import pandas as pd  # To load and inspect the dataset.

# Download a classification dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")

# Display the first 5 rows of the training data.
train_ds.head(5)
```
|  | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
We'll use the classic "Adult" dataset for this tutorial. The task is binary classification: predict whether an individual's income is >50K or <=50K based on the other numerical and categorical features. This dataset also contains missing values, which YDF handles automatically.
The column we want to predict is income:
```python
train_ds["income"]
```
|  | income |
|---|---|
| 0 | <=50K |
| 1 | <=50K |
| 2 | <=50K |
| 3 | <=50K |
| 4 | >50K |
| ... | ... |
| 22787 | <=50K |
| 22788 | >50K |
| 22789 | <=50K |
| 22790 | <=50K |
| 22791 | <=50K |
22792 rows × 1 columns
Now, let's train a Random Forest model to perform this classification task.
```python
# Note: task=ydf.Task.CLASSIFICATION is the default, so it's optional here.
model = ydf.RandomForestLearner(label="income").train(train_ds)
```
Train model on 22792 examples
Model trained in 0:00:01.753511
Classification models are commonly evaluated using metrics like accuracy, confusion matrices, ROC-AUC, and PR-AUC. YDF can compute these metrics and generate a rich evaluation report, including ROC and PR curves.
```python
evaluation = model.evaluate(test_ds)
evaluation
```
Evaluation of classification models
- Accuracy
  - The simplest metric. It's the percentage of predictions that are correct (matching the ground truth).
  - Example: If a model correctly identifies 90 out of 100 images as cat or dog, the accuracy is 90%.
- Confusion Matrix
  - A table that shows the counts of:
    - True Positives (TP): The model correctly predicted positive.
    - True Negatives (TN): The model correctly predicted negative.
    - False Positives (FP): The model incorrectly predicted positive (a "false alarm").
    - False Negatives (FN): The model incorrectly predicted negative (a "miss").
- Threshold
  - YDF classification models predict a probability for each class. A threshold determines the cutoff for classifying an example as positive or negative.
  - Example: If the threshold is 0.5, any prediction above 0.5 might be classified as "spam," and anything below as "not spam."
- ROC Curve (Receiver Operating Characteristic Curve)
  - A graph that plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at various thresholds.
    - TPR (Sensitivity or Recall): TP / (TP + FN). How many of the actual positives did the model catch?
    - FPR: FP / (FP + TN). How many of the actual negatives were incorrectly classified as positives?
  - Interpretation: A good model has an ROC curve that hugs the top-left corner (high TPR, low FPR).
- AUC (Area Under the ROC Curve)
  - A single number that summarizes the overall performance shown by the ROC curve. Unlike accuracy, the AUC does not depend on a single choice of threshold, which makes it a more stable metric. Multi-class classification models evaluate each class against all other classes.
  - Interpretation: Ranges from 0 to 1. A perfect model has an AUC of 1, while a random model has an AUC of 0.5. Higher is better.
- Precision-Recall Curve
  - A graph that plots Precision against Recall at various thresholds.
    - Precision: TP / (TP + FP). Out of all the predictions the model labeled as positive, how many were actually positive?
    - Recall (same as TPR): TP / (TP + FN). Out of all the actual positive cases, how many did the model correctly identify?
  - Interpretation: A good model has a curve that stays high (both high precision and high recall). This curve is especially useful for imbalanced datasets (e.g., when one class is much rarer than the other).
- PR-AUC (Area Under the Precision-Recall Curve)
  - Similar to AUC, but for the Precision-Recall curve: a single number summarizing performance. Multi-class classification models evaluate each class against all other classes. Higher is better.
- Threshold / Accuracy Curve
  - A graph that shows how the model's accuracy changes as you vary the classification threshold.
- Threshold / Volume Curve
  - A graph that shows how the number of examples classified as positive changes as you vary the threshold.
| Label \ Pred | <=50K | >50K |
|---|---|---|
| <=50K | 6962 | 450 |
| >50K | 861 | 1496 |
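The threshold-based metrics in the report can be reproduced by hand to see how they relate. The following is a minimal sketch on a tiny made-up dataset, not the Adult data and not YDF's own implementation. The helpers `threshold_metrics` and `roc_auc` are hypothetical, and the code assumes both classes are present.

```python
# Minimal sketch of the threshold-based metrics, computed by hand on a tiny
# made-up dataset (not the Adult data, and not YDF's own implementation).
# Assumes both classes are present in `labels`.


def threshold_metrics(labels, scores, threshold):
    """Confusion-matrix-derived metrics at a single threshold.

    labels: 1 for the positive class, 0 for the negative class.
    scores: predicted probability of the positive class.
    An example counts as predicted positive if its score >= threshold.
    """
    tp = fp = tn = fn = 0
    for y, s in zip(labels, scores):
        if s >= threshold:
            if y == 1:
                tp += 1
            else:
                fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    return {
        "tpr": tp / (tp + fn),  # recall / sensitivity
        "fpr": fp / (fp + tn),
        "accuracy": (tp + tn) / len(labels),
        "volume": tp + fp,  # number of examples predicted positive
    }


def roc_auc(labels, scores):
    """Area under the ROC curve, using the trapezoidal rule."""
    # Start at (FPR=0, TPR=0), i.e. a threshold above every score, then lower
    # the threshold through each distinct score. The last threshold marks every
    # example as positive, so the curve ends at (1, 1).
    points = [(0.0, 0.0)]
    for t in sorted(set(scores), reverse=True):
        m = threshold_metrics(labels, scores, t)
        points.append((m["fpr"], m["tpr"]))
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


labels = [0, 0, 1, 1, 0, 1]
scores = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]

# Sweeping the threshold gives points of the ROC, threshold/accuracy and
# threshold/volume curves.
for t in (0.3, 0.5, 0.85):
    print(t, threshold_metrics(labels, scores, t))

print(roc_auc(labels, scores))  # about 0.889: 8 of the 9 positive/negative pairs are ranked correctly
```

Using every distinct score as a threshold gives tied scores half credit. This is consistent with reading the AUC as the probability that a random positive is scored above a random negative. A constant score yields 0.5 and a perfect ranking yields 1.0.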
You can also access individual metrics directly from the evaluation object.
```python
print(evaluation.accuracy)
```
0.8657999795270754
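As a sanity check, the accuracy printed above can be recomputed from the confusion matrix in the report, treating >50K as the positive class. The following is a minimal sketch in plain Python. The counts are copied from the table, and the other metrics follow the formulas listed in the report.

```python
# Counts copied from the confusion matrix above (rows: label, columns: prediction).
# ">50K" is treated as the positive class.
tn = 6962  # label <=50K, predicted <=50K
fp = 450   # label <=50K, predicted >50K
fn = 861   # label >50K, predicted <=50K
tp = 1496  # label >50K, predicted >50K

total = tp + tn + fp + fn
accuracy = (tp + tn) / total
precision = tp / (tp + fp)
recall = tp / (tp + fn)  # same as the true positive rate
fpr = fp / (fp + tn)

print(f"examples:  {total}")            # 9769
print(f"accuracy:  {accuracy:.4f}")     # 0.8658
print(f"precision: {precision:.4f}")    # 0.7688
print(f"recall:    {recall:.4f}")       # 0.6347
print(f"FPR:       {fpr:.4f}")          # 0.0607
```

The total of 9769 matches the number of test examples, and the recomputed accuracy matches evaluation.accuracy.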
Making predictions
Classification models predict the probability for each potential class.
In binary classification, the predict() method returns the probability of the positive class. You can see the order of classes by checking model.label_classes(). The first class is the "negative" one and the second class is the "positive" one.
```python
# Print the label classes. The second class is the "positive" one.
print(f"Label classes: {model.label_classes()}")

# Predict the probability of the positive class ('>50K').
predictions = model.predict(test_ds)
print(f"Predictions (probabilities):\n{predictions}")
```
Label classes: ['<=50K', '>50K']
Predictions (probabilities):
[0.01333333 0.12999995 0.9499992 ... 0.06000001 0.02333334 0. ]
If you only need the final predicted class label instead of the probabilities, you can use the model.predict_class() method for convenience.
```python
model.predict_class(test_ds)
```
array(['<=50K', '<=50K', '>50K', ..., '<=50K', '<=50K', '<=50K'],
shape=(9769,), dtype='<U5')
Warning: The order of classes in model.label_classes() depends on the values in the training data. Always check this attribute if your logic depends on the numeric output of predict(). Using predict_class() avoids this ambiguity.
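To apply a threshold of your choice, for example to trade precision for recall, you can map the output of predict() to labels yourself. Read the class order from model.label_classes() instead of hard-coding it. The following is a minimal sketch. The helper `probabilities_to_labels` is hypothetical, and the probabilities are the first and last few values shown above. Counting a probability exactly equal to the threshold as positive is an arbitrary choice here.

```python
# Minimal sketch: map positive-class probabilities to labels with a custom
# threshold. `label_classes` and `probabilities` stand in for the outputs of
# model.label_classes() and model.predict(); the values are the first and
# last few predictions shown above.


def probabilities_to_labels(probabilities, label_classes, threshold=0.5):
    """Returns the positive class where p >= threshold, else the negative class."""
    negative, positive = label_classes  # the negative class is listed first
    return [positive if p >= threshold else negative for p in probabilities]


label_classes = ["<=50K", ">50K"]
probabilities = [0.01333333, 0.12999995, 0.9499992, 0.06000001, 0.02333334, 0.0]

print(probabilities_to_labels(probabilities, label_classes))
# ['<=50K', '<=50K', '>50K', '<=50K', '<=50K', '<=50K']
print(probabilities_to_labels(probabilities, label_classes, threshold=0.1))
# ['<=50K', '>50K', '>50K', '<=50K', '<=50K', '<=50K']
```

With a threshold of 0.5, these six values give the same labels as the corresponding predict_class() outputs shown above. A lower threshold flags more examples as >50K, which typically raises recall at the cost of precision. With a trained model, you would pass model.predict(test_ds) and model.label_classes() directly.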