pip install ydf
What is a custom loss?¶
In gradient boosted trees, the loss is a function that takes a label value and a prediction, and returns the "amount of error" of this prediction. The model is trained to minimize the average loss over all the training examples. YDF implements various common losses. You can configure them with the "loss" parameter. You can see the list of available losses here. If you don't specify the loss, it is selected automatically according to the model task. For instance, if the task is regression, the loss is set to mean-squared error by default.
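For example, a built-in loss can be selected by name when creating the learner. A minimal sketch (the label name is a placeholder; "SQUARED_ERROR" is the built-in squared-error loss used by default for regression):
import ydf

# Select a built-in loss by name with the "loss" hyperparameter.
learner = ydf.GradientBoostedTreesLearner(
    label="my_label",  # Placeholder: the name of your label column.
    task=ydf.Task.REGRESSION,
    loss="SQUARED_ERROR",
)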
If YDF does not support a loss you need, you can define it manually. This is called a "custom loss".
In this introductory tutorial, we will create a custom regression loss called Mean Squared Logarithmic Error.
Custom Losses in YDF¶
In YDF, a custom loss consists of four parts:
- Initial prediction: The initial prediction of the model, e.g. the average of the labels.
- Gradient and Hessian: A function that computes the gradient and the diagonal of the hessian of the loss, given the label and the prediction of the model before the activation function (a.k.a. the link function).
- Loss: A function that measures the quality of the current solution. While theory might dictate that the gradient and hessian are actually the gradient and hessian of the loss function, approximations do very well in practice.
- Activation: A function applied to the predictions to transform them to the correct space (e.g. probabilities for classification problems)
Training Gradient Boosted Trees with custom loss¶
We start by setting up a regression dataset.
# Load libraries
import ydf # Yggdrasil Decision Forests
import pandas as pd # We use Pandas to load small datasets
import numpy as np # We use numpy for numerical operations
import numpy.typing as npty
from typing import Tuple
# Download a regression dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
all_ds = pd.read_csv(f"{ds_path}/abalone.csv")
# Randomly split the dataset into a training (70%) and testing (30%) dataset
all_ds = all_ds.sample(frac=1)
split_idx = len(all_ds) * 7 // 10
train_ds = all_ds.iloc[:split_idx]
test_ds = all_ds.iloc[split_idx:]
# Print the first 5 training examples
train_ds.head(5)
| | Type | LongestShell | Diameter | Height | WholeWeight | ShuckedWeight | VisceraWeight | ShellWeight | Rings |
|---|---|---|---|---|---|---|---|---|---|
| 1681 | F | 0.620 | 0.540 | 0.165 | 1.1390 | 0.4995 | 0.2435 | 0.3570 | 11 |
| 1168 | M | 0.620 | 0.450 | 0.200 | 0.8580 | 0.4285 | 0.1525 | 0.2405 | 8 |
| 484 | M | 0.630 | 0.480 | 0.145 | 1.0115 | 0.4235 | 0.2370 | 0.3050 | 12 |
| 1594 | I | 0.525 | 0.400 | 0.140 | 0.6540 | 0.3050 | 0.1600 | 0.1690 | 7 |
| 1192 | M | 0.700 | 0.565 | 0.180 | 1.7510 | 0.8950 | 0.3355 | 0.4460 | 9 |
Mean Squared Logarithmic Error¶
We use the Mean Squared Logarithmic Error (MSLE) loss for this tutorial. The MSLE is calculated as
MSLE = $\frac{1}{n} \sum_{i=1}^n (\log(p_i + 1) - \log(a_i+1))^2$,
where $n$ is the total number of observations, $p_i$ and $a_i$ are the prediction and label of example $i$, respectively, and $\log$ denotes the natural logarithm.
The gradient of the MSLE loss with respect to the prediction $p_i$ is
$\frac{1}{n} \cdot \frac{2(\log(p_i + 1) - \log(a_i+1))}{p_i + 1}$
The hessian of the MSLE loss is a matrix. For simplicity and performance reasons, YDF only uses the diagonal of the hessian. The $i$th element of the diagonal is
$\frac{1}{n} \cdot \frac{2(1 - \log(p_i + 1) + \log(a_i+1))}{(p_i + 1)^2}$
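As a quick sanity check of these formulas (an illustrative snippet, not part of the original notebook), we can compare the analytic gradient against a finite-difference approximation at a single point:
# Check the analytic MSLE gradient against a finite difference
# for a single example with label a and prediction p.
a, p, n = 5.0, 3.0, 1
msle = lambda p: (np.log1p(p) - np.log1p(a)) ** 2 / n
analytic_grad = 2 * (np.log1p(p) - np.log1p(a)) / ((p + 1) * n)
eps = 1e-6
numeric_grad = (msle(p + eps) - msle(p - eps)) / (2 * eps)
print(analytic_grad, numeric_grad)  # Both approximately -0.2027.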
# If predictions are close to -1, numerical instabilities will distort the
# results. The predictions are therefore capped slightly above -1.
PREDICTION_MINIMUM = -1 + 1e-6
def loss_msle(
    labels: npty.NDArray[np.float32],
    predictions: npty.NDArray[np.float32],
    weights: npty.NDArray[np.float32],
) -> np.float32:
  # The weights are ignored: as noted at the end of this tutorial, the
  # examples assume unit weights.
  clipped_pred = np.maximum(PREDICTION_MINIMUM, predictions)
  return np.sum((np.log1p(clipped_pred) - np.log1p(labels)) ** 2) / len(labels)
def initial_predictions_msle(
    labels: npty.NDArray[np.float32], _: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
  # The constant prediction minimizing the MSLE: its log1p is the mean of
  # log1p(labels).
  return np.exp(np.mean(np.log1p(labels))) - 1
def grad_msle(
    labels: npty.NDArray[np.float32], predictions: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
  gradient = (2 / len(labels)) * (np.log1p(predictions) - np.log1p(labels)) / (predictions + 1)
  return gradient

def hessian_msle(
    labels: npty.NDArray[np.float32], predictions: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
  hessian = (2 / len(labels)) * (1 - np.log1p(predictions) + np.log1p(labels)) / (predictions + 1) ** 2
  return hessian
def gradient_and_hessian_msle(
    labels: npty.NDArray[np.float32], predictions: npty.NDArray[np.float32]
) -> Tuple[npty.NDArray[np.float32], npty.NDArray[np.float32]]:
  clipped_pred = np.maximum(PREDICTION_MINIMUM, predictions)
  return (grad_msle(labels, clipped_pred), hessian_msle(labels, clipped_pred))
# Construct the loss object.
msle_custom_loss = ydf.RegressionLoss(
initial_predictions=initial_predictions_msle,
gradient_and_hessian=gradient_and_hessian_msle,
loss=loss_msle,
activation=ydf.Activation.IDENTITY,
)
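Before training, the three functions can be exercised directly on a toy batch. This check is purely illustrative (YDF calls these functions internally during training):
# Exercise the custom loss components on a tiny synthetic batch.
toy_labels = np.array([1.0, 2.0, 3.0], dtype=np.float32)
toy_predictions = np.zeros(3, dtype=np.float32)
toy_weights = np.ones(3, dtype=np.float32)
print(initial_predictions_msle(toy_labels, toy_weights))  # ~1.88
print(loss_msle(toy_labels, toy_predictions, toy_weights))  # ~1.20
print(gradient_and_hessian_msle(toy_labels, toy_predictions))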
The model is trained as usual with the loss object as a hyperparameter.
model = ydf.GradientBoostedTreesLearner(label="Rings", task=ydf.Task.REGRESSION, loss=msle_custom_loss).train(train_ds)
Train model on 2923 examples
Using a custom loss. Note when using custom losses, hyperparameter `apply_link_function` is ignored. Use the losses' activation function instead.
Model trained in 0:00:01.596486
The model description shows the evolution of training loss and validation loss.
model.describe()
Task : REGRESSION
Label : Rings
Features (8) : Type LongestShell Diameter Height WholeWeight ShuckedWeight VisceraWeight ShellWeight
Weights : None
Trained with tuner : No
Model size : 2263 kB
Number of records: 2923
Number of columns: 9

Number of columns by type:
    NUMERICAL: 8 (88.8889%)
    CATEGORICAL: 1 (11.1111%)

Columns:

NUMERICAL: 8 (88.8889%)
    0: "Rings" NUMERICAL mean:9.97366 min:1 max:29 sd:3.26558 dtype:DTYPE_INT64
    2: "LongestShell" NUMERICAL mean:0.524798 min:0.075 max:0.815 sd:0.119372 dtype:DTYPE_FLOAT64
    3: "Diameter" NUMERICAL mean:0.408751 min:0.055 max:0.65 sd:0.0987606 dtype:DTYPE_FLOAT64
    4: "Height" NUMERICAL mean:0.139512 min:0.01 max:0.515 sd:0.0386353 dtype:DTYPE_FLOAT64
    5: "WholeWeight" NUMERICAL mean:0.830059 min:0.002 max:2.657 sd:0.488709 dtype:DTYPE_FLOAT64
    6: "ShuckedWeight" NUMERICAL mean:0.360019 min:0.001 max:1.488 sd:0.221456 dtype:DTYPE_FLOAT64
    7: "VisceraWeight" NUMERICAL mean:0.180917 min:0.0005 max:0.6415 sd:0.108618 dtype:DTYPE_FLOAT64
    8: "ShellWeight" NUMERICAL mean:0.238848 min:0.0015 max:1.005 sd:0.138498 dtype:DTYPE_FLOAT64

CATEGORICAL: 1 (11.1111%)
    1: "Type" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"M" 1087 (37.1878%) dtype:DTYPE_BYTES

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Variable importances measure the importance of an input feature for a model.
1. "ShellWeight" 0.532153 ################ 2. "WholeWeight" 0.351222 ####### 3. "ShuckedWeight" 0.244100 ## 4. "LongestShell" 0.235793 ## 5. "Height" 0.223162 # 6. "VisceraWeight" 0.210488 # 7. "Diameter" 0.206625 8. "Type" 0.185702
1. "WholeWeight" 145.000000 ################ 2. "ShellWeight" 87.000000 ######### 3. "Height" 20.000000 # 4. "VisceraWeight" 16.000000 # 5. "LongestShell" 12.000000 6. "Diameter" 11.000000 7. "ShuckedWeight" 5.000000
1. "ShuckedWeight" 1160.000000 ################ 2. "ShellWeight" 919.000000 ############ 3. "WholeWeight" 660.000000 ######## 4. "Height" 437.000000 ##### 5. "LongestShell" 424.000000 ##### 6. "Diameter" 406.000000 ##### 7. "VisceraWeight" 322.000000 #### 8. "Type" 31.000000
1. "ShellWeight" 0.000008 ################ 2. "WholeWeight" 0.000005 ######### 3. "VisceraWeight" 0.000002 ### 4. "LongestShell" 0.000002 ### 5. "ShuckedWeight" 0.000002 ### 6. "Diameter" 0.000001 ## 7. "Height" 0.000000 8. "Type" 0.000000
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Only printing the first tree.
Tree #0: "ShellWeight">=0.15375 [s:1.50679e-10 n:2648 np:1817 miss:1] ; pred:-6.08611e-07
├─(pos)─ "ShellWeight">=0.28975 [s:3.06419e-11 n:1817 np:904 miss:0] ; pred:1.50836
| ├─(pos)─ "ShellWeight">=0.40975 [s:2.08793e-11 n:904 np:293 miss:0] ; pred:1.25334
| | ├─(pos)─ "ShuckedWeight">=0.63925 [s:3.76591e-11 n:293 np:173 miss:0] ; pred:0.599562
| | | ├─(pos)─ "ShellWeight">=0.57775 [s:3.09601e-11 n:173 np:33 miss:0] ; pred:0.265588
| | | | ├─(pos)─ pred:0.0884813
| | | | └─(neg)─ pred:0.177107
| | | └─(neg)─ "ShellWeight">=0.5075 [s:3.65286e-11 n:120 np:29 miss:0] ; pred:0.333974
| | |   ├─(pos)─ pred:0.111759
| | |   └─(neg)─ pred:0.222215
| | └─(neg)─ "ShuckedWeight">=0.43875 [s:4.23098e-11 n:611 np:455 miss:0] ; pred:0.653777
| |   ├─(pos)─ "ShellWeight">=0.38925 [s:4.69731e-12 n:455 np:69 miss:0] ; pred:0.313559
| |   | ├─(pos)─ pred:0.0829214
| |   | └─(neg)─ pred:0.230638
| |   └─(neg)─ "WholeWeight">=1.09075 [s:1.84243e-11 n:156 np:31 miss:0] ; pred:0.340218
| |     ├─(pos)─ pred:0.0943271
| |     └─(neg)─ pred:0.245891
| └─(neg)─ "ShuckedWeight">=0.27975 [s:9.24118e-12 n:913 np:621 miss:1] ; pred:0.255019
|   ├─(pos)─ "ShellWeight">=0.23475 [s:1.3242e-11 n:621 np:353 miss:1] ; pred:0.0440083
|   | ├─(pos)─ "ShuckedWeight">=0.39825 [s:1.79254e-11 n:353 np:190 miss:0] ; pred:0.136942
|   | | ├─(pos)─ pred:-0.000799999
|   | | └─(neg)─ pred:0.137742
|   | └─(neg)─ "ShellWeight">=0.18475 [s:9.53197e-12 n:268 np:192 miss:1] ; pred:-0.0929341
|   |   ├─(pos)─ pred:-0.0292848
|   |   └─(neg)─ pred:-0.0636493
|   └─(neg)─ "ShellWeight">=0.18975 [s:4.18104e-11 n:292 np:132 miss:1] ; pred:0.211011
|     ├─(pos)─ "Type" is in [BITMAP] {M, F} [s:4.07772e-11 n:132 np:102 miss:0] ; pred:0.189359
|     | ├─(pos)─ pred:0.181646
|     | └─(neg)─ pred:0.00771209
|     └─(neg)─ "ShuckedWeight">=0.21575 [s:2.21914e-11 n:160 np:97 miss:1] ; pred:0.0216526
|       ├─(pos)─ pred:-0.0236986
|       └─(neg)─ pred:0.0453512
└─(neg)─ "Diameter">=0.2225 [s:1.05556e-10 n:831 np:697 miss:1] ; pred:-1.50836
  ├─(pos)─ "Type" is in [BITMAP] {<OOD>, M, F} [s:2.913e-11 n:697 np:242 miss:1] ; pred:-0.951147
  | ├─(pos)─ "ShuckedWeight">=0.233 [s:1.74077e-11 n:242 np:47 miss:1] ; pred:-0.151146
  | | ├─(pos)─ "VisceraWeight">=0.1525 [s:9.15881e-12 n:47 np:5 miss:1] ; pred:-0.0692974
  | | | ├─(pos)─ pred:-0.00298646
  | | | └─(neg)─ pred:-0.0663109
  | | └─(neg)─ "Height">=0.1025 [s:1.99501e-11 n:195 np:117 miss:1] ; pred:-0.0818482
  | |   ├─(pos)─ pred:-0.00643989
  | |   └─(neg)─ pred:-0.0754083
  | └─(neg)─ "ShellWeight">=0.112 [s:2.31179e-11 n:455 np:158 miss:1] ; pred:-0.800002
  |   ├─(pos)─ "Height">=0.1325 [s:1.79824e-11 n:158 np:15 miss:1] ; pred:-0.173648
  |   | ├─(pos)─ pred:0.00315428
  |   | └─(neg)─ pred:-0.176802
  |   └─(neg)─ "ShellWeight">=0.06875 [s:9.09147e-12 n:297 np:177 miss:1] ; pred:-0.626354
  |     ├─(pos)─ pred:-0.329338
  |     └─(neg)─ pred:-0.297016
  └─(neg)─ "ShellWeight">=0.02175 [s:7.89323e-11 n:134 np:78 miss:1] ; pred:-0.557212
    ├─(pos)─ "LongestShell">=0.2525 [s:8.02273e-12 n:78 np:70 miss:1] ; pred:-0.265629
    | ├─(pos)─ "VisceraWeight">=0.01875 [s:6.70755e-12 n:70 np:58 miss:1] ; pred:-0.231683
    | | ├─(pos)─ pred:-0.198798
    | | └─(neg)─ pred:-0.0328844
    | └─(neg)─ pred:-0.0339468
    └─(neg)─ "WholeWeight">=0.0165 [s:7.82022e-11 n:56 np:51 miss:1] ; pred:-0.291582
      ├─(pos)─ "VisceraWeight">=0.01025 [s:7.9962e-12 n:51 np:21 miss:1] ; pred:-0.251427
      | ├─(pos)─ pred:-0.096431
      | └─(neg)─ pred:-0.154996
      └─(neg)─ pred:-0.0401556
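Individual trees can also be printed directly, e.g. with model.print_tree(), which prints the model's first tree by default:
# Print the structure of the model's first tree.
model.print_tree()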
We can compare this model to a model trained with RMSE loss.
model.evaluate(test_ds)
Evaluation of regression models:

- RMSE (Root Mean Squared Error): The square root of the average squared difference between predictions and ground truth values. Lower RMSE is better, and it has the same units as the target variable, making it somewhat interpretable.
- Residual: The difference between a prediction and the ground truth per example (Prediction - Ground Truth).
- Residual Histogram: A histogram showing the distribution of the residuals. Ideally, it is roughly symmetrical and bell-shaped, centered around zero, indicating that the errors are random and not biased.
- Ground Truth Histogram: A histogram showing the distribution of the actual target values in your dataset.
- Prediction Histogram: A histogram showing the distribution of the model's predictions.
- Ground Truth vs Predictions Curve: A scatter plot where each point represents a data point; the x-axis is the ground truth value and the y-axis is the model's prediction. A perfect model would have all points on the diagonal (prediction = ground truth); deviations from this line show errors.
- Predictions vs Residual Curve: A scatter plot where the x-axis is the model's prediction and the y-axis is the residual. Ideally, the points scatter randomly around the horizontal line at zero; patterns (e.g., a funnel shape) might indicate problems with the model.
- Predictions vs Ground Truth Curve: Sometimes a fitted curve through the points of the Ground Truth vs Predictions scatter plot, visualizing whether the model systematically over- or under-predicts in certain ranges.
# A model trained with default regression loss (i.e. RMSE loss)
model_rmse_loss = ydf.GradientBoostedTreesLearner(label="Rings", task=ydf.Task.REGRESSION).train(train_ds)
model_rmse_loss.evaluate(test_ds)
Train model on 2923 examples
Model trained in 0:00:01.017847
Other custom losses¶
Binary Classification¶
For binary classification problems, the labels are integers (2 for the positive class and 1 for the negative class). The model is expected to return the probability of the positive class. YDF supports the Sigmoid activation function for losses that do not operate in the probability space.
For demonstration purposes, the code below re-implements the Binomial Log Likelihood Loss as a custom loss. Note that this loss is also available directly through the loss="BINOMIAL_LOG_LIKELIHOOD" hyperparameter.
def binomial_initial_predictions(
    labels: npty.NDArray[np.int32], weights: npty.NDArray[np.float32]
) -> np.float32:
  sum_weights = np.sum(weights)
  sum_weights_positive = np.sum((labels == 2) * weights)
  ratio_positive = sum_weights_positive / sum_weights
  if ratio_positive == 0.0:
    return -np.finfo(np.float32).max
  elif ratio_positive == 1.0:
    return np.finfo(np.float32).max
  # The initial prediction is the log-odds of the positive class.
  return np.log(ratio_positive / (1 - ratio_positive))
def binomial_gradient_and_hessian(
    labels: npty.NDArray[np.int32], predictions: npty.NDArray[np.float32]
) -> Tuple[npty.NDArray[np.float32], npty.NDArray[np.float32]]:
  # The predictions are logits; apply the sigmoid to obtain probabilities.
  pred_probability = 1.0 / (1.0 + np.exp(-predictions))
  binary_labels = labels == 2
  return (
      pred_probability - binary_labels,
      pred_probability * (pred_probability - 1),
  )
def binomial_loss(
    labels: npty.NDArray[np.int32],
    predictions: npty.NDArray[np.float32],
    weights: npty.NDArray[np.float32],
) -> np.float32:
  binary_labels = labels == 2
  return (
      -2.0
      * np.sum(binary_labels * predictions - np.log(1.0 + np.exp(predictions)))
      / len(labels)
  )
binomial_custom_loss = ydf.BinaryClassificationLoss(
initial_predictions=binomial_initial_predictions,
gradient_and_hessian=binomial_gradient_and_hessian,
loss=binomial_loss,
activation=ydf.Activation.SIGMOID,
)
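The loss object is then used like any built-in loss. As an illustrative sketch (assuming the "adult" binary classification dataset from the same test-data directory, with label column "income"):
# Illustrative usage of the custom binary loss. Assumes the "adult" dataset
# from the same test_data directory; its "income" label has two classes.
adult_train = pd.read_csv(f"{ds_path}/adult_train.csv")
model_binomial = ydf.GradientBoostedTreesLearner(
    label="income",
    task=ydf.Task.CLASSIFICATION,
    loss=binomial_custom_loss,
).train(adult_train)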
Multi-class classification¶
For multi-class classification problems, the labels are integers starting at 1. The loss function must provide a gradient and hessian for each label class: the gradient and hessian functions must each return a d-by-n matrix, where n is the number of examples and d is the number of label classes. Similarly, the model must provide an initial prediction for each label class, as a vector of d elements.
YDF supports the Softmax activation function for losses that do not operate in the probability space.
For demonstration purposes, the code below re-implements the Multinomial Log Likelihood Loss as a custom loss. Note that this loss is also available directly through the loss="MULTINOMIAL_LOG_LIKELIHOOD" hyperparameter.
def multinomial_initial_predictions(
labels: npty.NDArray[np.int32], _: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
dimension = np.max(labels)
return np.zeros(dimension, dtype=np.float32)
def multinomial_gradient(
    labels: npty.NDArray[np.int32], predictions: npty.NDArray[np.float32]
) -> Tuple[npty.NDArray[np.float32], npty.NDArray[np.float32]]:
  dimension = np.max(labels)
  normalization = 1.0 / np.sum(np.exp(predictions), axis=1)
  normalized_predictions = np.exp(predictions) * normalization[:, None]
  # One-hot encoding of the (1-based) labels.
  label_indicator = (
      (labels - 1)[:, np.newaxis] == np.arange(dimension)
  ).astype(int)
  gradient = normalized_predictions - label_indicator
  hessian = np.abs(gradient) * (np.abs(gradient) - 1)
  # YDF expects d-by-n matrices: transpose the n-by-d results.
  return (np.transpose(gradient), np.transpose(hessian))
def multinomial_loss(
    labels: npty.NDArray[np.int32],
    predictions: npty.NDArray[np.float32],
    weights: npty.NDArray[np.float32],
) -> np.float32:
  dimension = np.max(labels)
  sum_exp_pred = np.sum(np.exp(predictions), axis=1)
  indicator_matrix = (
      (labels - 1)[:, np.newaxis] == np.arange(dimension)
  ).astype(int)
  label_exp_pred = np.exp(np.sum(predictions * indicator_matrix, axis=1))
  return -np.sum(np.log(label_exp_pred / sum_exp_pred)) / len(labels)
multinomial_custom_loss = ydf.MultiClassificationLoss(
initial_predictions=multinomial_initial_predictions,
gradient_and_hessian=multinomial_gradient,
loss=multinomial_loss,
activation=ydf.Activation.SOFTMAX,
)
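To illustrate the expected d-by-n shapes, the gradient function can be run on a toy batch (an added check, not part of the original notebook):
# Toy batch: d=3 classes, n=4 examples, with one prediction column per class.
toy_labels = np.array([1, 2, 3, 2], dtype=np.int32)
toy_predictions = np.zeros((4, 3), dtype=np.float32)
gradient, hessian = multinomial_gradient(toy_labels, toy_predictions)
print(gradient.shape, hessian.shape)  # (3, 4) (3, 4): d-by-n, as YDF expects.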
Custom losses with JAX¶
JAX allows defining losses with auto-differentiation: only the loss function itself needs to be written by hand, and the gradient and hessian are derived automatically. In this example, we define the Huber loss for regression.
import jax
import jax.numpy as jnp
@jax.jit
def huber_loss(labels, pred, delta=1.0):
  abs_diff = jnp.abs(labels - pred)
  return jnp.average(
      jnp.where(
          abs_diff > delta,
          delta * (abs_diff - 0.5 * delta),
          0.5 * abs_diff**2,
      )
  )
huber_grad = jax.jit(jax.grad(huber_loss, argnums=1))
huber_hessian = jax.jit(jax.jacfwd(jax.jacrev(huber_loss, argnums=1)))
huber_init = jax.jit(lambda labels, weights: jnp.average(labels))
huber = ydf.RegressionLoss(
initial_predictions=jax.block_until_ready(huber_init),
gradient_and_hessian=lambda label, pred: (
huber_grad(label, pred).block_until_ready(),
jnp.diagonal(huber_hessian(label, pred)).block_until_ready()
),
loss=lambda label, pred, weight: huber_loss(label, pred).block_until_ready(),
activation=ydf.Activation.IDENTITY,
)
model = ydf.GradientBoostedTreesLearner(label="Rings", task=ydf.Task.REGRESSION, loss=huber).train(train_ds)
Train model on 2923 examples
Using a custom loss. Note when using custom losses, hyperparameter `apply_link_function` is ignored. Use the losses' activation function instead.
Model trained in 0:00:02.142357
Additional details and tips¶
- For simplicity of exposition, the examples above assume unit weights.
- Loss functions should not hold references to the labels, predictions, or weights arrays. These arrays are backed by C++ memory and might be deleted on the C++ side at any time.
- When using custom losses, YDF may trigger the GC to catch illegal memory accesses. Set may_trigger_gc=False on the loss object to avoid this, but be aware that YDF may then not warn about illegal memory accesses (see the sketch after this list).
- The arrays returned by the custom loss functions may be modified by YDF.
- Training with custom losses is often ~10% slower than training with built-in losses.
- Custom losses are not fully supported for model inspection and analysis: it is not yet possible to compute the model's custom loss on a test set in YDF.
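For example, a sketch of the MSLE loss object with the GC check disabled (assuming may_trigger_gc is accepted at construction, as the tip above suggests):
# Same MSLE loss as before, with the GC-based safety check disabled.
msle_custom_loss_no_gc = ydf.RegressionLoss(
    initial_predictions=initial_predictions_msle,
    gradient_and_hessian=gradient_and_hessian_msle,
    loss=loss_msle,
    activation=ydf.Activation.IDENTITY,
    may_trigger_gc=False,
)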