Classification
Setup
pip install ydf -U
What is classification?
Classification is the task of predicting a categorical value, such as an enum, type, or class from a finite set of possible values. For instance, predicting a color from the set of possible colors RED, BLUE, GREEN is a classification task. The output of classification models is a probability distribution over the possible classes. The predicted class is the one with the highest probability.
When there are only two classes, we call it binary classification. In this case, models only return one probability.
Classification labels can be strings, integers, or boolean values.
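As a minimal sketch of the two ideas above, the snippet below picks the most probable class from a distribution and rebuilds a two-class distribution from a single probability. It is plain Python and does not call YDF; the helper names `predicted_class` and `expand_binary`, the class names, and the probability values are all made up for illustration.

```python
def predicted_class(distribution):
    """Return the class with the highest probability."""
    return max(distribution, key=distribution.get)


def expand_binary(probability, reported_class, other_class):
    """Rebuild a two-class distribution from the single reported probability."""
    return {reported_class: probability, other_class: 1.0 - probability}


# Multi-class case: one probability per class, summing to 1.
color_distribution = {"RED": 0.2, "BLUE": 0.7, "GREEN": 0.1}
print(predicted_class(color_distribution))

# Binary case: a single probability is enough to recover both classes.
binary_distribution = expand_binary(0.25, "yes", "no")
print(predicted_class(binary_distribution))
```

The labels here are strings, but the same logic applies to integer or boolean labels.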
Training a classification model
The task of a model (e.g., classification, regression) is determined by the task learner argument. The default value of this argument is ydf.Task.CLASSIFICATION, which means that by default, YDF trains classification models.
# Load libraries
import ydf # Yggdrasil Decision Forests
import pandas as pd # We use Pandas to load small datasets
# Download a classification dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Print the first 5 training examples
train_ds.head(5)
The label column is:
train_ds["income"]
We can train a classification model:
# Note: ydf.Task.CLASSIFICATION is the default value of "task"
model = ydf.RandomForestLearner(
    label="income",
    task=ydf.Task.CLASSIFICATION,
).train(train_ds)
Classification models are evaluated using accuracy, confusion matrices, ROC-AUC and PR-AUC. You can plot a rich evaluation with ROC and PR plots.
evaluation = model.evaluate(test_ds)
evaluation
The evaluation metrics can be accessed directly in the evaluation object.
print(evaluation.accuracy)
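To make two of the metrics named above concrete, here is a rough sketch that computes accuracy and a confusion matrix by hand. It does not use YDF or reproduce its evaluation code; the labels and predictions are hand-written toy values, not output from the model trained above, and the helper names are hypothetical.

```python
from collections import Counter


def accuracy(labels, predictions):
    """Fraction of examples whose predicted class equals the true label."""
    correct = sum(1 for label, pred in zip(labels, predictions) if label == pred)
    return correct / len(labels)


def confusion_matrix(labels, predictions):
    """Count each (true label, predicted class) pair."""
    return Counter(zip(labels, predictions))


# Toy data, for illustration only.
labels = ["<=50K", "<=50K", ">50K", ">50K", "<=50K"]
predictions = ["<=50K", ">50K", ">50K", "<=50K", "<=50K"]

print(accuracy(labels, predictions))
for (label, pred), count in sorted(confusion_matrix(labels, predictions).items()):
    print(f"true={label} predicted={pred} count={count}")
```

The diagonal entries of the confusion matrix (true label equal to predicted class) are the correct predictions, so accuracy is their sum divided by the number of examples.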
Making predictions
Classification models predict the probability of the label classes.
Binary classification models output the probability of the first class according to model.label_classes().
# Print the label classes.
print(model.label_classes())
# Predict the probability of the first class.
print(model.predict(test_ds))
We can also directly predict the most likely class.
Warning: Always use model.predict_class() or manually check the order of classes using model.label_classes(). Note that the order of label classes may change depending on the training dataset or if YDF is updated.
model.predict_class(test_ds)
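The sketch below illustrates why the warning matters when converting probabilities to classes by hand. It is plain Python, not YDF code: the helper `classes_from_probabilities`, its `reported_index` parameter, the 0.5 threshold, and the class names are all assumptions made for illustration. The caller is expected to check which class the probabilities refer to instead of hard-coding it.

```python
def classes_from_probabilities(probabilities, classes, reported_index, threshold=0.5):
    """Map binary probabilities to class names.

    `reported_index` is the position in `classes` of the class whose
    probability is reported. This is an assumption the caller must verify.
    """
    other_index = 1 - reported_index
    return [
        classes[reported_index] if p >= threshold else classes[other_index]
        for p in probabilities
    ]


probabilities = [0.9, 0.2, 0.6]

# Same probabilities, two different class orders: the decoded labels flip.
print(classes_from_probabilities(probabilities, ["no", "yes"], reported_index=0))
print(classes_from_probabilities(probabilities, ["yes", "no"], reported_index=0))
```

Because the decoded labels depend entirely on the class order, code that hard-codes the order can silently invert its predictions if that order changes.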