Classification¶
Setup¶
pip install ydf -U
What is classification?¶
Classification is the task of predicting a categorical value, such as an enum, type, or class from a finite set of possible values. For instance, predicting a color from the set of possible colors RED, BLUE, GREEN is a classification task. The output of classification models is a probability distribution over the possible classes. The predicted class is the one with the highest probability.
When there are only two classes, we call it binary classification. In this case, models only return one probability.
Classification labels can be strings, integers, or boolean values.
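As a toy illustration of these definitions (not specific to YDF), here is a minimal sketch that turns a predicted probability distribution over the colors above into a predicted class:

import numpy as np

# Hypothetical output of a classification model for one example.
classes = ["RED", "BLUE", "GREEN"]
probabilities = np.array([0.2, 0.7, 0.1])  # sums to 1.0

# The predicted class is the one with the highest probability.
predicted_class = classes[int(np.argmax(probabilities))]
print(predicted_class)  # BLUE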
Training a classification model¶
The task of a model (e.g., classification, regression) is determined by the task learner argument. The default value of this argument is ydf.Task.CLASSIFICATION, which means that by default, YDF trains classification models.
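As a small sketch (the "income" label refers to the dataset loaded below), leaving task unset is equivalent to passing ydf.Task.CLASSIFICATION explicitly, and other tasks are selected through the same argument:

import ydf

# These two learners are equivalent: CLASSIFICATION is the default task.
learner_default = ydf.RandomForestLearner(label="income")
learner_explicit = ydf.RandomForestLearner(label="income", task=ydf.Task.CLASSIFICATION)

# A numerical label would instead be trained with task=ydf.Task.REGRESSION.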
# Load libraries
import ydf # Yggdrasil Decision Forests
import pandas as pd # We use Pandas to load small datasets
# Download a classification dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Print the first 5 training examples
train_ds.head(5)
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
The label column is:
train_ds["income"]
| | income |
|---|---|
| 0 | <=50K |
| 1 | <=50K |
| 2 | <=50K |
| 3 | <=50K |
| 4 | >50K |
| ... | ... |
| 22787 | <=50K |
| 22788 | >50K |
| 22789 | <=50K |
| 22790 | <=50K |
| 22791 | <=50K |
22792 rows × 1 columns
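Before training, it can be useful to check how the label classes are distributed; a quick sketch using standard Pandas (value_counts is a Pandas method, not part of YDF):

# Number of training examples per label class.
print(train_ds["income"].value_counts())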
We can train a classification model:
# Note: ydf.Task.CLASSIFICATION is the default value of "task"
model = ydf.RandomForestLearner(label="income",
                                task=ydf.Task.CLASSIFICATION).train(train_ds)
Train model on 22792 examples
Model trained in 0:00:01.808830
Classification models are evaluated using accuracy, confusion matrices, ROC-AUC, and PR-AUC. Evaluating a model produces a rich report that includes ROC and PR plots.
evaluation = model.evaluate(test_ds)
evaluation
Evaluation of classification models

- Accuracy: The simplest metric. It's the percentage of predictions that are correct (matching the ground truth). Example: If a model correctly identifies 90 out of 100 images as cat or dog, the accuracy is 90%.
- Confusion Matrix: A table that shows the counts of:
  - True Positives (TP): Model correctly predicted positive.
  - True Negatives (TN): Model correctly predicted negative.
  - False Positives (FP): Model incorrectly predicted positive (a "false alarm").
  - False Negatives (FN): Model incorrectly predicted negative (a "miss").
- Threshold: YDF classification models predict a probability for each class. A threshold determines the cutoff for classifying something as positive or negative. Example: If the threshold is 0.5, any prediction above 0.5 might be classified as "spam," and anything below as "not spam."
- ROC Curve (Receiver Operating Characteristic Curve): A graph that plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at various thresholds.
  - TPR (Sensitivity or Recall): TP / (TP + FN) - How many of the actual positives did the model catch?
  - FPR: FP / (FP + TN) - How many negatives were incorrectly classified as positives?
  Interpretation: A good model has an ROC curve that hugs the top-left corner (high TPR, low FPR).
- AUC (Area Under the ROC Curve): A single number that summarizes the overall performance shown by the ROC curve. The AUC is a more stable metric than the accuracy. Multi-class classification models evaluate one class against all other classes. Interpretation: Ranges from 0 to 1. A perfect model has an AUC of 1, while a random model has an AUC of 0.5. Higher is better.
- Precision-Recall Curve: A graph that plots Precision against Recall at various thresholds.
  - Precision: TP / (TP + FP) - Out of all the predictions the model labeled as positive, how many were actually positive?
  - Recall (same as TPR): TP / (TP + FN) - Out of all the actual positive cases, how many did the model correctly identify?
  Interpretation: A good model has a curve that stays high (both high precision and high recall). It is especially useful when dealing with imbalanced datasets (e.g., when one class is much rarer than the other).
- PR-AUC (Area Under the Precision-Recall Curve): Similar to AUC, but for the Precision-Recall curve. A single number summarizing performance. Multi-class classification models evaluate one class against all other classes. Higher is better.
- Threshold / Accuracy Curve: A graph that shows how the model's accuracy changes as you vary the classification threshold.
- Threshold / Volume Curve: A graph showing how the number of data points classified as positive changes as you vary the threshold.

These metrics are all derived from the confusion matrix counts; a worked example follows the confusion matrix below.
| Label \ Pred | <=50K | >50K |
|---|---|---|
| <=50K | 6962 | 861 |
| >50K | 450 | 1496 |
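To connect the definitions above to these numbers, here is a small sketch (plain Python, treating ">50K" as the positive class) that recomputes accuracy, precision, and recall from the confusion matrix:

# Counts from the confusion matrix above, with ">50K" as the positive class.
tp, fn = 1496, 450   # actual >50K: predicted >50K / predicted <=50K
fp, tn = 861, 6962   # actual <=50K: predicted >50K / predicted <=50K

accuracy = (tp + tn) / (tp + tn + fp + fn)   # ~0.8658, matches evaluation.accuracy
precision = tp / (tp + fp)                   # ~0.635
recall = tp / (tp + fn)                      # ~0.769 (also the TPR)
false_positive_rate = fp / (fp + tn)         # ~0.110
print(accuracy, precision, recall, false_positive_rate)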
The evaluation metrics can be accessed directly in the evaluation object.
print(evaluation.accuracy)
0.8657999795270754
Making predictions¶
Classification models predict the probability of the label classes. Binary classification models output the probability of the second class of model.label_classes() (here, ">50K"), as the predictions below show.
# Print the label classes.
print(model.label_classes())
# Predict the probability of the second class (">50K").
print(model.predict(test_ds))
['<=50K', '>50K']
[0.01333333 0.12999995 0.9499992 ... 0.06000001 0.02333334 0. ]
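If the full probability distribution over both classes is needed, it can be reconstructed from this single probability; a minimal sketch using NumPy:

import numpy as np

probs_positive = model.predict(test_ds)  # probability of ">50K", the second class
# Column 0: probability of "<=50K"; column 1: probability of ">50K".
probs_both = np.stack([1.0 - probs_positive, probs_positive], axis=1)
print(probs_both[:3])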
We can also directly predict the most likely class.
Warning: Always use model.predict_class() or manually check the order of classes using model.label_classes(). Note that the order of label classes may change depending on the training dataset or if YDF is updated.
model.predict_class(test_ds)
array(['<=50K', '<=50K', '>50K', ..., '<=50K', '<=50K', '<=50K'], shape=(9769,), dtype='<U5')
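For a binary model, the same classes can be obtained manually by thresholding the probabilities and indexing into model.label_classes(); a sketch assuming a 0.5 cutoff:

import numpy as np

classes = np.array(model.label_classes())   # ['<=50K', '>50K']
probs = model.predict(test_ds)              # probability of the second class, ">50K"

# Threshold at 0.5: index 0 selects '<=50K', index 1 selects '>50K'.
manual_classes = classes[(probs >= 0.5).astype(int)]
print(manual_classes[:5])  # should match model.predict_class(test_ds)[:5]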