While Keras provides a comprehensive suite of built-in metrics (`tf.keras.metrics`), situations often arise where you need to evaluate model performance using a criterion specific to your application domain or research question. For instance, you might need to track the F1-score for a specific class in a multi-class problem, calculate a domain-specific error rate, or implement a novel evaluation measure from a recent research paper. This is where developing custom metrics becomes essential, offering the flexibility highlighted in this chapter.
Before building your own, it's important to understand how Keras metrics operate. Unlike simple loss functions that compute a scalar value for a single batch, metrics often need to accumulate information across multiple batches within an epoch to provide a meaningful summary. Consider accuracy: calculating it per batch can be noisy; you typically want the overall accuracy across all samples seen so far in the epoch.
To handle this, Keras metrics are implemented as stateful objects. They maintain internal state variables that are updated incrementally with data from each batch. Whenever a value is needed, for example after each batch for the progress bar or at the end of an epoch or evaluation run, the metric computes it from these accumulated states.
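The built-in `tf.keras.metrics.Accuracy` makes this behavior visible. The minimal sketch below uses made-up labels and predictions. It feeds two batches of different sizes to a single metric instance and compares the accumulated result with a naive average of per-batch accuracies. `Accuracy` compares `y_true` and `y_pred` element-wise, so class IDs are passed directly rather than probabilities.

```python
import tensorflow as tf

# One stateful metric object shared across batches.
acc = tf.keras.metrics.Accuracy()

# Batch 1: both predictions correct (2 of 2).
acc.update_state([1, 2], [1, 2])
batch1 = float(acc.result())

# Batch 2: only the last prediction is correct (1 of 4).
acc.update_state([1, 2, 3, 4], [0, 0, 0, 4])
accumulated = float(acc.result())  # (2 + 1) / (2 + 4) = 0.5

# A naive mean of per-batch accuracies over-weights the small batch.
naive_mean = (1.0 + 0.25) / 2  # 0.625

print(f"after batch 1: {batch1:.3f}")
print(f"accumulated:   {accumulated:.3f}")
print(f"naive mean:    {naive_mean:.3f}")

# Clear the state, as Keras does at the start of each epoch.
acc.reset_state()
print(f"after reset:   {float(acc.result()):.3f}")
```

Because the metric keeps running totals rather than per-batch averages, every sample counts equally, regardless of which batch it arrived in.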
Every Keras metric, whether built-in or custom, typically inherits from the base class `tf.keras.metrics.Metric` and implements four main methods:

- `__init__(self, name='my_metric', **kwargs)`: The constructor. Here, you initialize the state variables required to compute the metric. Crucially, state variables should be created using the `self.add_weight()` method. This ensures they are properly tracked by TensorFlow, managed across different execution modes (eager vs. graph), and synchronized in distributed training scenarios.
- `update_state(self, y_true, y_pred, sample_weight=None)`: This method processes the labels (`y_true`) and predictions (`y_pred`) for a single batch and updates the internal state variables accordingly. It contains the core logic for accumulating statistics. The optional `sample_weight` argument allows individual samples to be weighted.
- `result(self)`: This method uses the values stored in the state variables to compute and return the final metric value as a `tf.Tensor`. It should not modify the state.
- `reset_state(self)`: This method resets all state variables to their initial values. Keras calls it automatically at the start of each epoch during `model.fit()` and at the start of `model.evaluate()`. The base class already provides a default that sets every state variable to zero, so overriding it is only required when the initial state is something else.

Let's illustrate this by creating a custom metric that counts the true positives for a specific class in a multi-class classification problem. This can be useful for monitoring performance on a class of particular interest.
```python
import tensorflow as tf

class CategoricalTruePositives(tf.keras.metrics.Metric):
    """
    Computes the number of true positives for a specific target class.

    Args:
        target_class_id: Integer, the class ID for which to compute true positives.
        name: String, name for the metric instance.
        dtype: Data type of the metric result. A float type is the default so
            that fractional sample weights are not truncated.
    """

    def __init__(self, target_class_id, name='categorical_true_positives', dtype=tf.float32, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)
        self.target_class_id = target_class_id
        # Initialize the state variable with add_weight so Keras tracks it
        self.true_positives = self.add_weight(
            name='tp',
            initializer='zeros',
            dtype=self.dtype  # Use the metric's dtype
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Sparse integer labels; flatten so (batch, 1) and (batch,) both become (batch,)
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        # Predicted class ID for each sample, shape (batch,)
        y_pred = tf.argmax(y_pred, axis=-1, output_type=tf.int32)

        # Identify true positives for the target class
        is_target_class = tf.equal(y_true, self.target_class_id)
        is_prediction_correct = tf.equal(y_true, y_pred)
        # Logical AND: the sample belongs to the target class and was predicted correctly
        batch_true_positives = tf.logical_and(is_target_class, is_prediction_correct)
        batch_true_positives = tf.cast(batch_true_positives, self.dtype)

        # Handle sample weights if provided
        if sample_weight is not None:
            sample_weight = tf.reshape(tf.cast(sample_weight, self.dtype), [-1])
            # Broadcast a scalar or per-sample weight to the batch shape
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(batch_true_positives))
            batch_true_positives = batch_true_positives * sample_weight

        # Update the state variable with this batch's (weighted) count
        self.true_positives.assign_add(tf.reduce_sum(batch_true_positives))

    def result(self):
        # Return the accumulated count
        return self.true_positives

    def reset_state(self):
        # Reset the state variable to zero
        self.true_positives.assign(0)

    # Optional: configuration for saving/loading
    def get_config(self):
        config = super().get_config()
        config.update({'target_class_id': self.target_class_id})
        return config
```
In this example:

- `__init__` stores the `target_class_id` and initializes a single state variable, `self.true_positives`, to zero using `self.add_weight`.
- `update_state` takes batch predictions and labels, determines the predicted class using `tf.argmax`, checks which samples belong to the `target_class_id` and were correctly predicted, and increments the `self.true_positives` count accordingly, optionally taking sample weights into account. All operations use TensorFlow functions to ensure graph compatibility.
- `result` simply returns the current value of `self.true_positives`.
- `reset_state` resets `self.true_positives` to 0.
- `get_config` is added so that the metric's configuration (including its `target_class_id`) is saved with the model and can be used to re-create the metric at load time. Loading still requires making the class itself available, for example through the `custom_objects` argument of `tf.keras.models.load_model`.

Integrating a custom metric into your workflow is straightforward. You instantiate it and pass it to the `metrics` list in `model.compile()`:
```python
# Assume `model` is a defined Keras model for multi-class classification
# with num_classes = 10 and integer (sparse) labels
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=[
        'accuracy',  # Standard metric
        CategoricalTruePositives(target_class_id=3, name='true_positives_class_3'),  # Custom metric for class 3
        CategoricalTruePositives(target_class_id=7, name='true_positives_class_7'),  # Another instance for class 7
    ]
)

# Now you can train or evaluate the model
# history = model.fit(train_dataset, epochs=5, validation_data=val_dataset)
# results = model.evaluate(test_dataset, return_dict=True)
# history.history and the results dictionary will contain values for
# 'true_positives_class_3' and 'true_positives_class_7'
# (history also gets 'val_'-prefixed entries when validation data is given)
```
Keras will automatically manage the metric's state updates and resets during training and evaluation loops.
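To see the automatic per-epoch reset in action, the minimal sketch below defines a deliberately trivial custom metric, `SampleCount`. This is a hypothetical name for illustration, not a Keras built-in, and it simply counts the samples it has seen. The sketch trains a tiny model on random data for two epochs. Because `fit` calls `reset_state` at the start of every epoch, each epoch should report the dataset size (64) rather than a running total across epochs.

```python
import numpy as np
import tensorflow as tf

class SampleCount(tf.keras.metrics.Metric):
    """Hypothetical metric: counts the samples seen since the last reset."""

    def __init__(self, name="sample_count", **kwargs):
        super().__init__(name=name, **kwargs)
        # State lives in a tracked variable so Keras can reset it.
        self.seen = self.add_weight(name="seen", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Add the batch size (first dimension of the predictions).
        batch_size = tf.cast(tf.shape(y_pred)[0], self.dtype)
        self.seen.assign_add(batch_size)

    def result(self):
        return self.seen

    def reset_state(self):
        self.seen.assign(0.0)

# Random toy data: 64 samples, 4 features, 3 classes.
x = np.random.rand(64, 4).astype("float32")
y = np.random.randint(0, 3, size=(64,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=[SampleCount()])

history = model.fit(x, y, batch_size=16, epochs=2, verbose=0)
print(history.history["sample_count"])  # one value per epoch, each 64.0
```

If the state were held in a plain Python attribute instead of an `add_weight` variable, Keras would have no way to reset it, and the count would keep growing across epochs.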
When implementing custom metrics, keep the following points in mind:

- Make sure all computations in `update_state` and `result` use TensorFlow functions (`tf.*`). Avoid Python loops or conditional logic that AutoGraph cannot translate, because Keras runs training and evaluation steps inside `tf.function` by default. NumPy operations or pure Python logic within these methods will likely cause errors or performance issues in graph mode.
- Pay attention to data types. Use `tf.cast` explicitly when necessary to avoid type mismatch errors, and create state variables with `add_weight` using the appropriate `dtype`.
- Be mindful of tensor shapes, especially when applying `sample_weight` or comparing predictions and labels. Use `tf.shape`, `tf.rank`, and TensorFlow's broadcasting rules deliberately rather than relying on implicit alignment.
- Using `self.add_weight` for state variables is fundamental for distributed training. TensorFlow's distribution strategies rely on this mechanism to correctly aggregate metric states across devices or workers. Plain Python attributes will not be synchronized.
- Initialize state variables to their correct starting values in `__init__`, and ensure `reset_state` correctly reverts them to this initial state.

By mastering the creation of custom metrics, you gain precise control over how your model's performance is measured, enabling more insightful model evaluation and development tailored to your specific goals.
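Of these points, shape handling is the easiest to get wrong with sparse labels. If labels are stored with shape `(N, 1)`, each batch's `y_true` arrives as `(batch, 1)`, while `tf.argmax(y_pred, axis=-1)` returns `(batch,)`. `tf.equal` broadcasts the two to a `(batch, batch)` matrix without raising an error, which distorts any count built on it. The minimal sketch below uses made-up values to show the problem and one common fix, flattening the labels with `tf.reshape`.

```python
import tensorflow as tf

# Sparse labels for a batch of 4, stored as a column: shape (4, 1).
y_true = tf.constant([[0], [1], [2], [1]])
# Predicted probabilities for 3 classes: shape (4, 3).
y_pred = tf.constant([
    [0.9, 0.05, 0.05],  # predicts 0 (correct)
    [0.1, 0.8, 0.1],    # predicts 1 (correct)
    [0.2, 0.2, 0.6],    # predicts 2 (correct)
    [0.7, 0.2, 0.1],    # predicts 0 (wrong, label is 1)
])
pred_ids = tf.argmax(y_pred, axis=-1, output_type=tf.int32)  # shape (4,)

# Pitfall: (4, 1) vs (4,) silently broadcasts to (4, 4).
bad = tf.equal(tf.cast(y_true, tf.int32), pred_ids)
bad_count = int(tf.reduce_sum(tf.cast(bad, tf.int32)))
print(bad.shape, bad_count)  # (4, 4) 5  (more "matches" than samples)

# Fix: flatten the labels to (batch,) before comparing.
labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
good = tf.equal(labels, pred_ids)
num_correct = int(tf.reduce_sum(tf.cast(good, tf.int32)))
print(good.shape, num_correct)  # (4,) 3
```

The broadcast version reports 5 matches for a batch of only 4 samples. That impossible count is the kind of symptom worth checking for whenever a custom metric returns suspicious values.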
© 2025 ApX Machine Learning