pip install ydf transformers torch jax tqdm datasets scikit-learn matplotlib -U
from itertools import islice
from datasets import load_dataset # The text dataset
import matplotlib.pyplot as plt # For plotting the toy dataset
import numpy as np
from sklearn.decomposition import PCA # PCA is used to reduce the embedding dimension
from sklearn.preprocessing import StandardScaler
import torch
from tqdm import tqdm # For the progress-bar
from transformers import GPT2Model, GPT2Tokenizer # To compute some embeddings
import ydf
What are Vector Sequence features?¶
Vector sequence features are a type of input feature where each value is a sequence (or list) of multi-dimensional, fixed-size numerical vectors. They are well-suited for encoding sets or time series of embeddings, such as embeddings of a collection of images or the embeddings of intermediate layers within a Large Language Model (LLM).
They can be seen as an extension of multi-dimensional numerical features, as illustrated below:
Type | Example of value |
---|---|
(Single dimensional) numerical | 4.3 |
Multi-dimensional numerical | [1,5,2] |
Vector sequence numerical | [[1,2,3], [4,5,6], [7,8,9]] |
While the number of vectors within different vector sequence values can vary, all vectors within a given sequence must have the same dimensionality (shape).
About this tutorial¶
This tutorial is divided into two parts. The first part shows how to create sequence features on simple toy data.
The second part shows a more complex example that combines LLM embeddings, PCA, and vector sequences: We'll use the first hidden state of the GPT2 model, apply a PCA to reduce its dimensionality (an optional step to make training faster), and use it as a vector sequence feature for text classification with a decision forest.
Part 1: Vector sequence on a simple toy example¶
For in-memory datasets, vector sequences are represented as Python lists of NumPy arrays, where each array has shape (number of vectors, vector dimension); see the example after the notes below.
Note:
- Pandas DataFrames don't handle multi-dimensional values well; a plain Python dictionary of values is simpler and more efficient.
- For file-based datasets, a format capable of representing two-dimensional values is needed. Currently, Avro is the only natively supported format for this.
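For instance, an in-memory dataset with two examples and a single vector sequence feature can be written as follows (a minimal sketch):
# Each feature value is a list of NumPy arrays of shape
# (number of vectors, vector dimension). The number of vectors can vary
# between examples, but the vector dimension must be constant.
example_ds = {
    "label": np.array([True, False]),
    "feature": [
        np.array([[1.0, 2.0], [3.0, 4.0]]),  # A sequence of two 2d vectors.
        np.array([[5.0, 6.0]]),  # A sequence of one 2d vector.
    ],
}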
Our toy dataset is simple: Each feature value is a list of 0 to 4 two-dimensional points with coordinates in [-1.5, 1.5]. A sample is labeled "true" if at least one of its points is within the unit circle (centered at (0,0) with a radius of 1), and "false" otherwise. Around 50% of the examples are positive. Let's build this dataset.
def make_toy_ds(num_examples=1_000):
    features = []
    labels = []
    for _ in range(num_examples):
        # Each example contains between 0 and 4 two-dimensional points.
        num_vectors = np.random.randint(0, 5)
        vectors = np.random.uniform(-1.5, 1.5, [num_vectors, 2])
        # The label is true iff at least one point is inside the unit circle.
        label = np.any(np.sum(vectors**2, axis=1) < 1)
        features.append(vectors)
        labels.append(label)
    return {"label": np.array(labels), "feature": features}
# Generate 3 examples
make_toy_ds(num_examples=3)
{'label': array([False, False, True]), 'feature': [array([[ 0.47221051, -1.04686068], [ 1.40348894, -1.00676166], [ 0.63440287, 1.29930153], [-1.28285904, -1.44044944]]), array([[-1.44968141, 1.18143043], [-1.35090648, 1.05524487], [-0.47623758, 1.02229518]]), array([[-0.73034892, -1.03253695], [ 0.47363613, -0.61725529], [-0.02912192, -1.47682402]])]}
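We can also check empirically that roughly half of the examples are positive:
# Fraction of positive labels; should be close to 0.5.
print(np.mean(make_toy_ds(num_examples=10_000)["label"]))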
Let's plot some examples to make sure the pattern is there.
num_examples = 3
dataset = make_toy_ds(num_examples)
fig, axs = plt.subplots(1, 3, figsize=(10, 3))
for example_idx, ax in enumerate(axs):
feature = dataset["feature"][example_idx]
ax.scatter([v[0] for v in feature], [v[1] for v in feature])
ax.set_title(f"label={dataset['label'][example_idx]}")
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
# Show the circle
ax.add_artist(
plt.Circle((0, 0), 1, edgecolor="blue", facecolor="none", linewidth=1)
)
Then, we can train our model.
train_ds = make_toy_ds(num_examples=10_000)
model = ydf.RandomForestLearner(label="label").train(train_ds)
Train model on 10000 examples Model trained in 0:00:05.606750
The following cell shows the model's description. In the "Dataspec" tab, you can see the feature statistics (e.g., distribution of the number of vectors, vector dimension). In the "Structure" tab, you can see the learned tree conditions.
For example, the condition "feature" contains X with | X - [0.054303, -0.062462] |² <= 0.996597
evaluates to true iff the sequence contains at least one vector whose squared distance to (0.054303, -0.062462) is at most 0.996597, i.e., a distance of less than about 0.998. This is very close to the rule we used to generate the dataset: a distance of less than 1 from (0, 0).
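We can verify this reading of the condition by evaluating it by hand on a small sequence (a sketch; the anchor and threshold are copied from the condition above):
anchor = np.array([0.054303, -0.062462])
threshold = 0.996597  # Threshold on the *squared* distance.
vectors = np.array([[1.2, 0.3], [0.1, -0.2]])
# True iff at least one vector is close enough to the anchor.
# Here, the second vector satisfies the condition.
print(np.any(np.sum((vectors - anchor) ** 2, axis=1) <= threshold))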
model.describe()
Task : CLASSIFICATION
Label : label
Features (1) : feature
Weights : None
Trained with tuner : No
Model size : 13257 kB
Number of records: 10000
Number of columns: 2

Number of columns by type:
    NUMERICAL_VECTOR_SEQUENCE: 1 (50%)
    CATEGORICAL: 1 (50%)

Columns:

NUMERICAL_VECTOR_SEQUENCE: 1 (50%)
    1: "feature" NUMERICAL_VECTOR_SEQUENCE mean:0.00616592 min:-1.49985 max:1.49994 sd:0.865995 dims:2 min-vecs:0 max-vecs:4 dtype:DTYPE_FLOAT64

CATEGORICAL: 1 (50%)
    0: "label" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"false" 5080 (50.8%) dtype:DTYPE_BOOL

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Number of predictions (without weights): 10000
Number of predictions (with weights): 10000
Task: CLASSIFICATION
Label: label

Accuracy: 0.9936 CI95[W][0.992125 0.994854]
LogLoss: 0.0304112
ErrorRate: 0.00639999

Default Accuracy: 0.508
Default LogLoss: 0.693019
Default ErrorRate: 0.492

Confusion Table:
truth\prediction  false  true
false              5063    17
true                 47  4873
Total: 10000
Variable importances measure the importance of an input feature for a model. Since this model has a single input feature, "feature" is ranked first by each of the four importance measures:

1. "feature" 1.000000
1. "feature" 300.000000
1. "feature" 19649.000000
1. "feature" 2054347.537343
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
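For instance, such an analysis can be run on a held-out dataset (a sketch; model.analyze computes, among other things, permutation variable importances):
model.analyze(make_toy_ds(num_examples=1_000))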
Only printing the first tree.
Tree #0:
    "feature" contains X with | X - [-0.073972, -0.13002] |² <= 1.13198 [s:0.474337 n:10000 np:5257 miss:1] ; val:"false" prob:[0.5169, 0.4831]
    ├─(pos)─ "feature" contains X with | X - [0.23399, -0.012537] |² <= 1.14279 [s:0.116865 n:5257 np:4686 miss:1] ; val:"true" prob:[0.0996766, 0.900323]
    |        ├─(pos)─ "feature" contains X with | X - [0.046393, 0.021299] |² <= 0.925552 [s:0.11139 n:4686 np:4403 miss:1] ; val:"true" prob:[0.0352113, 0.964789]
    |        |        ├─(pos)─ val:"true" prob:[0, 1]
    |        |        └─(neg)─ "feature" contains X with X @ [2.6077, -1.6338] >= 1.79877 [s:0.0892376 n:283 np:236 miss:0] ; val:"false" prob:[0.583039, 0.416961]
    |        |                 ...
    |        └─(neg)─ "feature" contains X with | X - [-0.83747, 0.27321] |² <= 0.959256 [s:0.113909 n:571 np:390 miss:1] ; val:"false" prob:[0.628722, 0.371278]
    |                 ...
    └─(neg)─ "feature" contains X with | X - [0.44434, 0.9401] |² <= 0.654926 [s:0.027429 n:4743 np:1042 miss:1] ; val:"false" prob:[0.979338, 0.020662]
             ...
Finally, we can evaluate the model.
test_ds = make_toy_ds(num_examples=1000)
model.evaluate(test_ds)
Evaluation of classification models

- Accuracy: The simplest metric. It's the percentage of predictions that are correct (matching the ground truth). Example: If a model correctly identifies 90 out of 100 images as cat or dog, the accuracy is 90%.
- Confusion Matrix: A table that shows the counts of:
  - True Positives (TP): Model correctly predicted positive.
  - True Negatives (TN): Model correctly predicted negative.
  - False Positives (FP): Model incorrectly predicted positive (a "false alarm").
  - False Negatives (FN): Model incorrectly predicted negative (a "miss").
- Threshold: YDF classification models predict a probability for each class. A threshold determines the cutoff for classifying something as positive or negative. Example: If the threshold is 0.5, any prediction above 0.5 might be classified as "spam," and anything below as "not spam."
- ROC Curve (Receiver Operating Characteristic Curve): A graph that plots the True Positive Rate (TPR) against the False Positive Rate (FPR) at various thresholds.
  - TPR (Sensitivity or Recall): TP / (TP + FN) - How many of the actual positives did the model catch?
  - FPR: FP / (FP + TN) - How many negatives were incorrectly classified as positives?
  Interpretation: A good model has an ROC curve that hugs the top-left corner (high TPR, low FPR).
- AUC (Area Under the ROC Curve): A single number that summarizes the overall performance shown by the ROC curve. The AUC is a more stable metric than the accuracy. Multi-class classification models evaluate one class against all other classes. Interpretation: Ranges from 0 to 1. A perfect model has an AUC of 1, while a random model has an AUC of 0.5. Higher is better.
- Precision-Recall Curve: A graph that plots Precision against Recall at various thresholds.
  - Precision: TP / (TP + FP) - Out of all the predictions the model labeled as positive, how many were actually positive?
  - Recall (same as TPR): TP / (TP + FN) - Out of all the actual positive cases, how many did the model correctly identify?
  Interpretation: A good model has a curve that stays high (both high precision and high recall). It is especially useful when dealing with imbalanced datasets (e.g., when one class is much rarer than the other).
- PR-AUC (Area Under the Precision-Recall Curve): Similar to AUC, but for the Precision-Recall curve. A single number summarizing performance. Multi-class classification models evaluate one class against all other classes. Higher is better.
- Threshold / Accuracy Curve: A graph that shows how the model's accuracy changes as you vary the classification threshold.
- Threshold / Volume Curve: A graph showing how the number of data points classified as positive changes as you vary the threshold.
Label \ Pred | false | true |
---|---|---|
false | 492 | 5 |
true | 1 | 502 |
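The returned Evaluation object also exposes the metrics programmatically, e.g., through its accuracy accessor:
evaluation = model.evaluate(test_ds)
# Expected ≈ (492 + 502) / 1000 = 0.994 given the confusion table above.
print(evaluation.accuracy)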
Part 2: LLM embedding + PCA + Vector Sequence¶
Now that we understand vector sequences, let's build a text classifier using GPT2's first hidden layer and a Random Forest.
Our model pipeline:
- GPT2 Tokenizer: Converts text to tokens, e.g., "the cat is red" → [362, 82, 673, 6543].
- GPT2 Token Embedding: Grabs the embedding of each token. Output shape: <num tokens, 768>.
- GPT2 Hidden State: Extracts an intermediate hidden state of GPT2. We use hidden state #0, the output of the embedding layer; larger indices would also apply the corresponding attention blocks. Output shape: <num tokens, 768>.
- PCA: Reduces the dimensionality of the embeddings. Output shape: <num tokens, 50>.
- Random Forest: Trains a Random Forest to classify texts from the PCA-transformed embeddings.
GPT2 model¶
We load GPT2 tokenizer and weights.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2Model.from_pretrained("gpt2", output_hidden_states=True)
The tokenizer encodes text into a list of token indices.
tokens = gpt2_tokenizer("This is a good movie", return_tensors="pt")
print(tokens["input_ids"])
tensor([[1212, 318, 257, 922, 3807]])
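The token indices can be mapped back to token strings with the tokenizer (the Ġ prefix marks a leading space in GPT2's byte-pair encoding):
print(gpt2_tokenizer.convert_ids_to_tokens(tokens["input_ids"][0].tolist()))
# Expected: ['This', 'Ġis', 'Ġa', 'Ġgood', 'Ġmovie']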
Then, we apply the GPT2 model to the tokens and extract hidden state #0. Note that hidden_states[0] is the output of the embedding layer (token + position embeddings); selecting a larger index would also apply the corresponding attention layers.
selected_hidden_layer = 0
gpt2_model(**tokens).hidden_states[selected_hidden_layer]
tensor([[[ 0.0065, -0.2930, 0.0762, ..., 0.0184, -0.0275, 0.1638], [ 0.0142, -0.0437, -0.0393, ..., 0.1487, -0.0278, -0.0255], [-0.0464, -0.0791, 0.1016, ..., 0.0623, 0.0928, -0.0598], [-0.0841, -0.1244, 0.1423, ..., -0.1435, -0.0718, -0.1183], [ 0.0331, -0.0645, 0.3507, ..., -0.0210, 0.0279, 0.1440]]], grad_fn=<AddBackward0>)
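The model returns one hidden state per transformer block plus the embedding output, so index 0 selects the embedding output:
# 13 for the 12-layer GPT2 base model: embeddings + one per block.
print(len(gpt2_model(**tokens).hidden_states))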
Let's group those two steps in a function that returns a numpy array.
def text_to_embedding(text: str) -> np.ndarray:
tokens = gpt2_tokenizer(text, return_tensors="pt")
return (
gpt2_model(**tokens)
.hidden_states[selected_hidden_layer]
.detach()
.numpy()[0]
)
text_to_embedding("This is a good movie")
array([[ 0.00649832, -0.29302013, 0.07615747, ..., 0.01843522, -0.02754061, 0.16376127], [ 0.01423593, -0.0437407 , -0.0392998 , ..., 0.14866675, -0.02783391, -0.02553328], [-0.04641282, -0.07912885, 0.10156769, ..., 0.06225622, 0.09284618, -0.05983091], [-0.08413801, -0.12438498, 0.14228812, ..., -0.14347112, -0.07182924, -0.1183255 ], [ 0.03311015, -0.06451828, 0.35070336, ..., -0.02101075, 0.0278743 , 0.14398581]], dtype=float32)
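Since we only run forward passes, inference can also be wrapped in torch.no_grad() to skip gradient tracking and save time and memory (an optional variant of the function above):
def text_to_embedding_no_grad(text: str) -> np.ndarray:
    tokens = gpt2_tokenizer(text, return_tensors="pt")
    # Disabling autograd avoids building the backward graph.
    with torch.no_grad():
        hidden = gpt2_model(**tokens).hidden_states[selected_hidden_layer]
    return hidden.numpy()[0]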
Load the dataset¶
AG News is a text classification dataset, where the task is to predict an article's category based on its content. Let's load it.
def ag_news_dataset(split: str):
class_mapping = {
0: "World",
1: "Sports",
2: "Business",
3: "Sci/Tech",
}
for example in load_dataset("ag_news")[split]:
yield {
"text": example["text"],
"label": class_mapping[example["label"]],
}
# Print the first 3 training examples
for example_idx, example in enumerate(islice(ag_news_dataset("train"), 3)):
print(f"==========\nExample #{example_idx}\n----------")
print(example)
==========
Example #0
----------
{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 'Business'}
==========
Example #1
----------
{'text': 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.', 'label': 'Business'}
==========
Example #2
----------
{'text': "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.", 'label': 'Business'}
Next, we load and embed a subset of the dataset.
num_examples = 1000  # Only embed 1000 examples to keep the tutorial fast.
labels = []
embeddings = []
for example in tqdm(
islice(ag_news_dataset("train"), num_examples), total=num_examples
):
embeddings.append(text_to_embedding(example["text"]))
labels.append(example["label"])
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:41<00:00, 24.00it/s]
Each example has a different number of tokens, i.e., a different shape for the embedding matrix.
print("First example embedding size:", embeddings[0].shape)
print("Second example embedding size:", embeddings[1].shape)
First example embedding size: (37, 768) Second example embedding size: (55, 768)
Compress the embedding¶
Our dataset is small, so we can speed up training by using PCA to reduce the 768-dimensional embeddings to 50 dimensions without losing much accuracy. Let's do that.
# Collect all the embeddings into a single matrix.
combined_embeddings = np.concatenate(embeddings, axis=0)

# Normalize the embeddings (this is necessary for PCA). We keep the fitted
# scaler so the same normalization can be applied to each example later.
scaler = StandardScaler()
normalized_combined_embeddings = scaler.fit_transform(combined_embeddings)

# Learn the compressed representation.
pca = PCA(n_components=50)
_ = pca.fit(normalized_combined_embeddings)
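To check how much signal the 50 components retain, we can inspect the explained variance (the exact value depends on the sampled data):
# Total fraction of the variance captured by the 50 principal components.
print(pca.explained_variance_ratio_.sum())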
We apply the same normalization and PCA projection to compress each example's embeddings.
reduced_embeddings = [pca.transform(scaler.transform(e)) for e in embeddings]
print("First example embedding size:", reduced_embeddings[0].shape)
print("Second example embedding size:", reduced_embeddings[1].shape)
First example embedding size: (37, 50) Second example embedding size: (55, 50)
Finally, we assemble the data into a dictionary.
dataset = {"label": np.array(labels), "reduced_embeddings": reduced_embeddings}
Train model¶
We can now train our model.
model = ydf.RandomForestLearner(label="label").train(dataset, verbose=2)
Train model on 1000 examples Model trained in 0:00:49.579565
Since the model is a Random Forest, we can look at the model's self-evaluation (a.k.a. out-of-bag evaluation) to estimate its quality.
model.self_evaluation()
Label \ Pred | Business | Sci/Tech | Sports | World |
---|---|---|---|---|
Business | 81 | 4 | 0 | 4 |
Sci/Tech | 87 | 461 | 36 | 43 |
Sports | 0 | 0 | 100 | 5 |
World | 6 | 7 | 6 | 160 |
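To classify a new text, we chain the same steps: embed, normalize, project with the PCA, and predict. Below is a sketch with a hypothetical classify_text helper reusing the objects defined above; note that the feature key must match the "reduced_embeddings" name used during training:
def classify_text(text: str):
    embedding = text_to_embedding(text)
    reduced = pca.transform(scaler.transform(embedding))
    # Returns the per-class probabilities for the example.
    return model.predict({"reduced_embeddings": [reduced]})

classify_text("The team won the championship game last night")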