While Keras provides a comprehensive suite of standard loss functions in tf.keras.losses, suitable for many common machine learning tasks, you'll inevitably encounter scenarios where these aren't sufficient. You might, for example, need a loss that is less sensitive to outliers or one that compensates for severe class imbalance.

TensorFlow offers straightforward ways to define your own custom loss functions and integrate them seamlessly into the standard Keras training workflow (model.compile(), model.fit()). This section explores how to create these custom losses, giving you precise control over the objective your model optimizes.
There are two primary ways to define a custom loss:

- A plain Python function: the simplest option for straightforward, stateless calculations.
- Subclassing tf.keras.losses.Loss: preferred for losses that require hyperparameters, internal state, or custom serialization logic.

The most direct way to create a custom loss is by defining a Python function that accepts two arguments: y_true (the ground truth labels) and y_pred (the model's predictions). This function should return a tensor containing the loss value for each sample in the batch. Keras handles the aggregation (reduction) of these per-sample losses, using either the reduction argument you specify later or the default behavior.
The function must use TensorFlow operations exclusively to ensure it can be traced into a graph by tf.function
and run efficiently on accelerators like GPUs or TPUs.
Let's implement the Huber loss, a function less sensitive to outliers than Mean Squared Error (MSE). It behaves quadratically for small errors and linearly for large errors. The formula is:
$$
L_\delta(y, f(x)) =
\begin{cases}
\frac{1}{2}\,(y - f(x))^2 & \text{for } |y - f(x)| \le \delta \\
\delta\left(|y - f(x)| - \frac{1}{2}\delta\right) & \text{otherwise}
\end{cases}
$$

Here, $y$ represents y_true, $f(x)$ represents y_pred, and $\delta$ is a threshold parameter.
import tensorflow as tf
def huber_loss(y_true, y_pred, delta=1.0):
"""
Calculates the Huber loss between y_true and y_pred.
Args:
y_true: Ground truth values. shape = [batch_size, d0, .. dN]
y_pred: The predicted values. shape = [batch_size, d0, .. dN]
delta: The point where the loss changes from quadratic to linear.
Returns:
A tf.Tensor containing the Huber loss values for each sample.
shape = [batch_size, d0, .. dN-1]
"""
y_true = tf.cast(y_true, dtype=y_pred.dtype) # Ensure types match
error = y_true - y_pred
abs_error = tf.abs(error)
quadratic = tf.minimum(abs_error, delta)
linear = abs_error - quadratic
return 0.5 * tf.square(quadratic) + delta * linear
# --- Usage Example ---
# Assuming you have a compiled Keras model:
# model.compile(optimizer='adam', loss=huber_loss)
# You can also pass configuration via functools.partial or lambda
# from functools import partial
# model.compile(optimizer='adam', loss=partial(huber_loss, delta=0.8))
# model.compile(optimizer='adam', loss=lambda yt, yp: huber_loss(yt, yp, delta=0.8))
# Keras will automatically handle the reduction (usually averaging over the batch)
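As a quick sanity check, you can call the function directly on small tensors and compare the batch average with the built-in tf.keras.losses.Huber. The tensors below are illustrative values chosen so that two samples fall in the quadratic regime and one in the linear regime.

import tensorflow as tf

y_true = tf.constant([[0.0], [1.0], [4.0]])
y_pred = tf.constant([[0.5], [1.0], [1.0]])

# Per-sample Huber values from the custom function
per_sample = huber_loss(y_true, y_pred, delta=1.0)
print(per_sample.numpy())  # [[0.125], [0.], [2.5]]

# The batch mean matches the built-in Huber loss, which reduces by default
print(tf.reduce_mean(per_sample).numpy())                           # 0.875
print(tf.keras.losses.Huber(delta=1.0)(y_true, y_pred).numpy())     # 0.875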
This function-based approach is clean and simple for straightforward, stateless calculations. However, if your loss function requires configurable parameters (like delta above, though we hardcoded a default or used wrappers) or needs to maintain state across batches (less common for losses), subclassing tf.keras.losses.Loss is more robust. Using wrappers like lambda or partial can also hinder model serialization.
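If you only need to make a parameter like delta configurable while sticking with a plain function, a small factory (closure) is a common alternative to partial or lambda. The sketch below simply reuses the huber_loss defined above and carries the same serialization caveat: the function itself is not stored with the model.

def make_huber_loss(delta=1.0):
    """Returns a Huber loss function with the given delta baked in."""
    def loss_fn(y_true, y_pred):
        return huber_loss(y_true, y_pred, delta=delta)
    return loss_fn

# model.compile(optimizer='adam', loss=make_huber_loss(delta=0.8))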
Subclassing tf.keras.losses.Loss
For greater flexibility, maintainability, and to handle hyperparameters cleanly, you can create a custom loss by subclassing tf.keras.losses.Loss. This object-oriented approach allows you to:

- Accept and store hyperparameters cleanly in the constructor (__init__).
- Support model saving and loading by implementing get_config.

To subclass tf.keras.losses.Loss, you primarily need to:

1. Override the __init__ method to accept any hyperparameters and call the parent constructor (super().__init__(...)). You can specify the default reduction strategy here (e.g., SUM_OVER_BATCH_SIZE, SUM, NONE) and provide a name for the loss.
2. Implement the call(self, y_true, y_pred) method. This is where the core loss calculation logic resides, similar to the function-based approach. It receives y_true and y_pred and should return the per-sample loss tensor.

Let's implement Focal Loss, often used in object detection or classification tasks with extreme class imbalance. It down-weights the contribution of easy examples, allowing the model to focus on hard-to-classify examples.
The formula for binary Focal Loss is:
$$
FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
$$

where $p_t$ is the model's predicted probability for the true class, $\alpha_t$ is the class-balancing weight ($\alpha$ for positive examples, $1 - \alpha$ for negative ones), and $\gamma \ge 0$ is the focusing parameter.
import tensorflow as tf
import tensorflow.keras.backend as K
class FocalLoss(tf.keras.losses.Loss):
"""
Implements the Focal Loss function for binary classification.
Args:
alpha: Weighting factor for balancing positive/negative classes.
Float in [0, 1]. Defaults to 0.25.
gamma: Focusing parameter. Non-negative float. Defaults to 2.0.
reduction: Type of tf.keras.losses.Reduction to apply.
Defaults to SUM_OVER_BATCH_SIZE.
name: Optional name for the loss instance. Defaults to 'focal_loss'.
"""
def __init__(self, alpha=0.25, gamma=2.0,
reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
name='focal_loss'):
super().__init__(reduction=reduction, name=name)
self.alpha = alpha
self.gamma = gamma
def call(self, y_true, y_pred):
"""
Calculates the focal loss.
Args:
y_true: True labels (binary 0 or 1). Shape [batch_size, 1] or [batch_size].
y_pred: Predicted probabilities. Shape [batch_size, 1] or [batch_size].
Returns:
Loss tensor with shape compatible with the reduction strategy.
"""
y_true = tf.cast(y_true, dtype=y_pred.dtype)
# Ensure predictions are probabilities (e.g., apply sigmoid if inputs are logits)
# For stability, clip predictions to avoid log(0)
epsilon = K.epsilon()
y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
# Calculate p_t
p_t = tf.where(tf.equal(y_true, 1), y_pred, 1. - y_pred)
# Calculate alpha_t
alpha_t = tf.where(tf.equal(y_true, 1), self.alpha, 1. - self.alpha)
# Calculate focal loss components
cross_entropy = -tf.math.log(p_t)
weight = alpha_t * tf.pow(1. - p_t, self.gamma)
# Compute the final focal loss
loss = weight * cross_entropy
return loss # Reduction is handled by the base Loss class
def get_config(self):
"""Returns the serializable config dictionary."""
base_config = super().get_config()
return {**base_config, "alpha": self.alpha, "gamma": self.gamma}
# --- Usage Example ---
# model.compile(optimizer='adam', loss=FocalLoss(alpha=0.3, gamma=1.5))
# Model saving will work correctly because of get_config
# loaded_model = tf.keras.models.load_model(
# 'my_model.keras',
# custom_objects={'FocalLoss': FocalLoss}
# )
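Before wiring the loss into a full training run, you can also call an instance directly on small tensors. The values below are arbitrary probabilities used only to illustrate the behavior: with the default reduction the instance returns a scalar, while Reduction.NONE keeps one loss value per sample.

import tensorflow as tf

y_true = tf.constant([[1.0], [0.0], [1.0]])
y_pred = tf.constant([[0.9], [0.1], [0.3]])  # already probabilities, not logits

# Default reduction (SUM_OVER_BATCH_SIZE) collapses the batch to a scalar
fl = FocalLoss(alpha=0.25, gamma=2.0)
print(fl(y_true, y_pred).numpy())

# Reduction.NONE returns the unreduced per-sample values for inspection
fl_none = FocalLoss(alpha=0.25, gamma=2.0,
                    reduction=tf.keras.losses.Reduction.NONE)
print(fl_none(y_true, y_pred).numpy())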
Keep a few practical points in mind when writing custom losses:

- Use TensorFlow operations: stick to tf operations (tf.math, tf.where, tf.cast, etc.) within your loss function, whether it's a simple function or a Loss subclass. This ensures compatibility with graph execution, automatic differentiation, and hardware acceleration. Avoid using NumPy or pure Python operations that TensorFlow cannot track.
- Check tensor shapes: pay attention to the shapes of y_true and y_pred. They usually need to be compatible for element-wise operations. Use tf.debugging.assert_shapes or print shapes during debugging if you encounter errors.
- Numerical stability: guard against operations that can produce NaN or Inf values, such as log(0) or division by zero. Use tf.clip_by_value to constrain inputs to safe ranges or add a small epsilon (tf.keras.backend.epsilon()) where appropriate, as shown in the FocalLoss example.
- Reduction: Keras applies the default reduction (SUM_OVER_BATCH_SIZE) or the one specified in model.compile(..., loss=custom_loss(reduction=...)) (if the loss function is designed to accept it, which is less common). When subclassing tf.keras.losses.Loss, the reduction strategy is handled automatically based on the reduction argument passed to the parent __init__. The call method should return per-sample losses.
- Serialization: if you want to save your model with model.save() and load it later, subclassing tf.keras.losses.Loss and implementing get_config is essential. When loading the model, you'll need to pass the custom loss class to the custom_objects argument of tf.keras.models.load_model. Simple Python functions used as losses are generally not serialized with the model's architecture, requiring you to re-supply the function when loading.
- Compatibility with model.fit: both custom function losses and Loss subclasses work seamlessly with the standard Keras model.fit training loop. Keras automatically handles applying the loss, calculating gradients via tf.GradientTape, and updating model weights.

Here's a simple visualization comparing Mean Squared Error (MSE) with Huber loss (δ=1.0). Notice how Huber loss increases linearly for larger errors, making it less influenced by outliers compared to the quadratic increase of MSE.
Comparison of Mean Squared Error (MSE) and Huber loss (δ=1.0) as a function of prediction error.
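If you want to reproduce a comparison like this yourself, the short sketch below uses NumPy and Matplotlib (both assumed to be available) to plot the two curves over a range of errors.

import numpy as np
import matplotlib.pyplot as plt

delta = 1.0
errors = np.linspace(-4.0, 4.0, 200)

mse = errors ** 2
abs_err = np.abs(errors)
huber = np.where(abs_err <= delta,
                 0.5 * errors ** 2,
                 delta * (abs_err - 0.5 * delta))

plt.plot(errors, mse, label='MSE')
plt.plot(errors, huber, label=f'Huber (delta={delta})')
plt.xlabel('Prediction error')
plt.ylabel('Loss value')
plt.legend()
plt.show()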
By implementing custom loss functions, you gain a powerful tool for tailoring your model's training objective precisely to the task at hand, moving beyond standard formulations to address specific challenges and requirements in your machine learning projects.