pip install ydf scikit-learn plotly -U
import ydf # Yggdrasil Decision Forests
import numpy as np
import pandas as pd # We use Pandas to load small datasets
What are counterfactual examples?¶
Counterfactual examples are the training examples most similar to a test example according to a model. They can be used to explain the model's prediction and to identify potential issues in the training dataset. Counterfactual examples are found using an example distance.
What is an example distance?¶
Decision forest models define an implicit measure of proximity or similarity between two examples, referred to as distance. The distance reflects how similarly the model treats two examples. Informally, two examples are close if they are of the same class and for the same reasons.
This distance is useful for understanding models and their predictions. For example, we can use it for clustering, manifold learning, or simply to look at the training examples that are nearest to a test example (called counterfactual examples). This can help us to understand why the model made its predictions.
Keep in mind that a decision forest's distance measure is just one of many reasonable distance metrics on a dataset. One of its advantages is that it allows comparing features on different scales and with different semantics.
In this notebook, we will train a model and use its distance to:
Find training examples that are neighbors of a test example and use them to explain the model's predictions.
Map all the examples onto an interactive two-dimensional plot (also known as a 2D manifold) and automatically detect two-dimensional clusters of examples that behave similarly.
Apply hierarchical clustering to explain how the model works as a whole.
The More You Know: Leo Breiman, the author of the random forest learning algorithm, proposed a method to measure the proximity between two examples using a pre-trained Random Forest (RF) model. He described this method as "[...] one of the most useful tools in random forests." When using Random Forest models, this is the distance used by YDF.
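Breiman's proximity has a simple definition: route both examples through every tree of the forest and count the fraction of trees in which they land in the same leaf; a distance can then be defined as one minus this proximity. The following minimal sketch illustrates the idea with a made-up `leaves` array (real leaf assignments would come from a trained forest):

```python
import numpy as np

# Hypothetical leaf assignments: rows are examples, columns are trees.
# leaves[i, t] is the index of the leaf that example i reaches in tree t.
leaves = np.array([
    [0, 2, 1],  # example A
    [0, 2, 3],  # example B
    [4, 1, 3],  # example C
])

def rf_distance(a, b):
    # Proximity = fraction of trees where both examples reach the same leaf;
    # distance = 1 - proximity.
    return 1.0 - np.mean(leaves[a] == leaves[b])

print(rf_distance(0, 1))  # A and B share 2 of 3 leaves -> distance 1/3
print(rf_distance(0, 2))  # A and C share 0 of 3 leaves -> distance 1.0
```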
Find closest training examples to a test example¶
Let's download a classification dataset.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Print the first 5 training examples
train_ds.head(5)
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
We train a random forest on this dataset.
model = ydf.RandomForestLearner(label="income").train(train_ds)
Train model on 22792 examples
Model trained in 0:00:00.931491
We need to select an example to explain. Let's select the first example of the test dataset.
selected_example = test_ds[:1]
selected_example
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
On this example, the model predicts:
model.predict(selected_example)
array([0.01], dtype=float32)
In other words, the model predicts the negative class <=50K with probability $1-0.01=99\%$.
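For binary classification, the model outputs a single number per example: the probability of the positive class (here >50K). A small sketch of turning such a probability into a class label and a negative-class probability (the 0.01 value is taken from the prediction above):

```python
import numpy as np

# Prediction from the model above: probability of the positive class ">50K".
pred = np.array([0.01], dtype=np.float32)

# Threshold at 0.5 to obtain a class label.
labels = np.where(pred >= 0.5, ">50K", "<=50K")
print(labels)  # -> ['<=50K']

# Probability of the negative class "<=50K".
print(1.0 - pred[0])
```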
Now, we compute the distance between the selected test example and all the training examples.
distances = model.distance(train_ds, selected_example).squeeze()
print("distances:", distances)
distances: [1. 1. 1. ... 0.99333334 0.99666667 1. ]
Let's find the five training examples with the smallest distance to our chosen example.
close_train_idxs = np.argsort(distances)[:5]
print("close_train_idxs:", close_train_idxs)
print("Closest training examples:")
train_ds.iloc[close_train_idxs]
close_train_idxs: [16596 21845 10321 7299 14721]
Closest training examples:
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
16596 | 41 | State-gov | 26892 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
21845 | 37 | State-gov | 60227 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 38 | United-States | <=50K |
10321 | 40 | Private | 82161 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
7299 | 30 | State-gov | 158291 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
14721 | 32 | State-gov | 171111 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 37 | United-States | <=50K |
Observations:
- For the chosen example, the model predicted the class <=50K. For the five closest examples, the model makes the same prediction.
- The closest examples share many feature values, such as education, marital_status, occupation, race, and working between 37 and 40 hours_per_week. This explains why these examples are close to each other.
- The examples' age ranges between 30 and 41, meaning the model treats this age range as equivalent for those examples.
import plotly.graph_objs as go
from plotly.offline import iplot # For interactive plots
from sklearn.manifold import TSNE # For 2d projections
import plotly.io as pio
pio.renderers.default="colab"
# Pairwise distance between all testing examples
distances = model.distance(test_ds, test_ds)
# Find 2d projection
t_sne = TSNE(
# Number of dimensions to display. 3d is also possible.
n_components=2,
# Controls the shape of the projection. Higher values create more
# distinct but also more collapsed clusters. Typical values are 5-50.
perplexity=20,
metric="precomputed",
init="random",
verbose=1,
learning_rate="auto",
).fit_transform(distances)
[t-SNE] Computing 61 nearest neighbors...
[t-SNE] Indexed 9769 samples in 0.055s...
[t-SNE] Computed neighbors for 9769 samples in 0.753s...
[t-SNE] Computed conditional probabilities for sample 1000 / 9769
[t-SNE] Computed conditional probabilities for sample 2000 / 9769
[t-SNE] Computed conditional probabilities for sample 3000 / 9769
[t-SNE] Computed conditional probabilities for sample 4000 / 9769
[t-SNE] Computed conditional probabilities for sample 5000 / 9769
[t-SNE] Computed conditional probabilities for sample 6000 / 9769
[t-SNE] Computed conditional probabilities for sample 7000 / 9769
[t-SNE] Computed conditional probabilities for sample 8000 / 9769
[t-SNE] Computed conditional probabilities for sample 9000 / 9769
[t-SNE] Computed conditional probabilities for sample 9769 / 9769
[t-SNE] Mean sigma: 0.178857
[t-SNE] KL divergence after 250 iterations with early exaggeration: 75.775681
[t-SNE] KL divergence after 1000 iterations: 1.117051
Let's create an interactive plot with the example features.
def example_to_html(example):
return "<br>".join([f"<b>{k}:</b> {v}" for k, v in example.items()])
def interactive_plot(dataset, projections):
colors = (dataset["income"] == ">50K").map(lambda x: ["red", "blue"][x])
labels = list(dataset.apply(example_to_html, axis=1).values)
args = {
"data": [
go.Scatter(
x=projections[:, 0],
y=projections[:, 1],
text=labels,
mode="markers",
marker={"color": colors, "size": 3},
)
],
"layout": go.Layout(width=500, height=500, template="simple_white"),
}
iplot(args)
interactive_plot(test_ds, t_sne)
Note: Move your mouse over the plot to see the values of the examples.
The colors represent the labels. We can see clusters of uniform colors (clusters where all the labels are the same), and clusters of mixed colors (clusters where the model has difficulty making good predictions).
Can you make sense of those clusters?
Cluster examples¶
We can also cluster examples. Many methods are available. Let's use AgglomerativeClustering.
from sklearn.cluster import AgglomerativeClustering
num_clusters = 6
clustering = AgglomerativeClustering(
n_clusters=num_clusters,
metric="precomputed",
linkage="average",
).fit(distances)
Next, we print the statistics of the features and one example in each cluster.
import IPython
for cluster_idx in range(num_clusters):
selected_examples = test_ds[clustering.labels_ == cluster_idx]
print(f"Cluster #{cluster_idx} with {len(selected_examples)} examples")
print("=============================")
IPython.display.display(selected_examples.describe())
IPython.display.display(selected_examples.iloc[:1])
Cluster #0 with 2879 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 2879.000000 | 2879.000000 | 2879.000000 | 2879.000000 | 2879.000000 | 2879.000000 |
mean | 42.860021 | 184706.465439 | 8.879125 | 200.963876 | 32.425842 | 42.555054 |
std | 12.426582 | 99424.684674 | 1.929070 | 852.256462 | 231.362238 | 11.910265 |
min | 18.000000 | 19395.000000 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 33.000000 | 115465.500000 | 9.000000 | 0.000000 | 0.000000 | 40.000000 |
50% | 41.000000 | 176681.000000 | 9.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 51.000000 | 231872.500000 | 10.000000 | 0.000000 | 0.000000 | 46.000000 |
max | 90.000000 | 671292.000000 | 12.000000 | 5013.000000 | 2179.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 40 | Private | 121772 | Assoc-voc | 11 | Married-civ-spouse | Craft-repair | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | NaN | >50K |
Cluster #1 with 5131 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 5131.000000 | 5.131000e+03 | 5131.000000 | 5131.000000 | 5131.000000 | 5131.000000 |
mean | 34.026895 | 1.931768e+05 | 9.726954 | 103.289222 | 57.424479 | 37.824401 |
std | 13.371512 | 1.055196e+05 | 2.395434 | 642.138022 | 328.764194 | 12.401540 |
min | 17.000000 | 1.921400e+04 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 23.000000 | 1.205865e+05 | 9.000000 | 0.000000 | 0.000000 | 35.000000 |
50% | 31.000000 | 1.817210e+05 | 10.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 42.000000 | 2.416855e+05 | 11.000000 | 0.000000 | 0.000000 | 40.000000 |
max | 90.000000 | 1.038553e+06 | 16.000000 | 7443.000000 | 3770.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
Cluster #2 with 220 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 220.000000 | 220.000000 | 220.000000 | 220.0 | 220.000000 | 220.000000 |
mean | 44.863636 | 182932.690909 | 11.977273 | 0.0 | 1996.745455 | 46.700000 |
std | 11.372463 | 89132.990647 | 2.314227 | 0.0 | 174.160632 | 11.490357 |
min | 22.000000 | 20953.000000 | 9.000000 | 0.0 | 1825.000000 | 12.000000 |
25% | 36.750000 | 125575.250000 | 10.000000 | 0.0 | 1887.000000 | 40.000000 |
50% | 43.000000 | 169627.500000 | 13.000000 | 0.0 | 1902.000000 | 41.000000 |
75% | 51.000000 | 213384.500000 | 14.000000 | 0.0 | 1977.000000 | 50.000000 |
max | 83.000000 | 530099.000000 | 16.000000 | 0.0 | 2603.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
25 | 48 | Self-emp-not-inc | 191277 | Doctorate | 16 | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 1902 | 60 | United-States | >50K |
Cluster #3 with 1012 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 1012.000000 | 1.012000e+03 | 1012.000000 | 1012.000000 | 1012.000000 | 1012.000000 |
mean | 43.610672 | 1.862960e+05 | 13.541502 | 119.073123 | 14.341897 | 44.180830 |
std | 11.334174 | 1.074333e+05 | 0.874553 | 675.355176 | 152.606687 | 12.260441 |
min | 23.000000 | 2.232800e+04 | 13.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 35.000000 | 1.148158e+05 | 13.000000 | 0.000000 | 0.000000 | 40.000000 |
50% | 43.000000 | 1.756480e+05 | 13.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 50.000000 | 2.301228e+05 | 14.000000 | 0.000000 | 0.000000 | 50.000000 |
max | 90.000000 | 1.097453e+06 | 16.000000 | 5013.000000 | 1977.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2 | 40 | Private | 193524 | Doctorate | 16 | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 0 | 60 | United-States | >50K |
Cluster #4 with 46 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 46.000000 | 46.000000 | 46.000000 | 46.000000 | 46.000000 | 46.000000 |
mean | 47.913043 | 171906.630435 | 15.543478 | 280.413043 | 252.260870 | 43.021739 |
std | 10.897148 | 81143.023865 | 0.503610 | 1088.008529 | 679.865949 | 16.043675 |
min | 32.000000 | 33155.000000 | 15.000000 | 0.000000 | 0.000000 | 6.000000 |
25% | 39.000000 | 115998.000000 | 15.000000 | 0.000000 | 0.000000 | 40.000000 |
50% | 48.500000 | 163298.000000 | 16.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 53.000000 | 211152.250000 | 16.000000 | 0.000000 | 0.000000 | 50.000000 |
max | 79.000000 | 345259.000000 | 16.000000 | 4787.000000 | 2824.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
618 | 36 | Private | 103110 | Doctorate | 16 | Never-married | Prof-specialty | Not-in-family | White | Male | 0 | 0 | 40 | England | <=50K |
Cluster #5 with 481 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 481.000000 | 481.000000 | 481.000000 | 481.000000 | 481.0 | 481.000000 |
mean | 45.621622 | 191274.322245 | 11.806653 | 19103.544699 | 0.0 | 46.636175 |
std | 11.010141 | 103664.004053 | 2.507912 | 25872.337100 | 0.0 | 11.647901 |
min | 20.000000 | 19302.000000 | 1.000000 | 5178.000000 | 0.0 | 2.000000 |
25% | 38.000000 | 119793.000000 | 10.000000 | 7298.000000 | 0.0 | 40.000000 |
50% | 44.000000 | 175232.000000 | 13.000000 | 10520.000000 | 0.0 | 45.000000 |
75% | 52.000000 | 235786.000000 | 14.000000 | 15024.000000 | 0.0 | 50.000000 |
max | 78.000000 | 617021.000000 | 16.000000 | 99999.000000 | 0.0 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
21 | 44 | Private | 343591 | HS-grad | 9 | Divorced | Craft-repair | Not-in-family | White | Female | 14344 | 0 | 40 | United-States | >50K |
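Beyond per-cluster feature statistics, the label composition of each cluster is a useful summary: pure clusters indicate regions where the model's behavior is consistent. The sketch below uses small made-up cluster_labels and income arrays for illustration; in the notebook, clustering.labels_ and test_ds["income"] would take their place:

```python
import numpy as np

# Hypothetical cluster assignments and labels (stand-ins for
# clustering.labels_ and test_ds["income"]).
cluster_labels = np.array([0, 0, 1, 1, 1, 2, 2])
income = np.array(["<=50K", ">50K", ">50K", ">50K", "<=50K", "<=50K", "<=50K"])

# For each cluster, report its size and the fraction of ">50K" labels.
for c in np.unique(cluster_labels):
    mask = cluster_labels == c
    rate = np.mean(income[mask] == ">50K")
    print(f"Cluster {c}: {mask.sum()} examples, {rate:.0%} >50K")
```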