趋近智
Keras 提供了全面的一套内置指标 (tf.keras.metrics),但经常会出现需要使用针对特定应用场景或研究问题定义的标准来评估模型性能的情况。例如,你可能需要追踪多类别问题中某个特定类别的 F1 分数,计算特定领域的错误率,或实现最新研究论文中提出的新评估方法。这时,开发自定义指标就变得十分必要,它提供了所需的灵活性。
在构建自己的指标之前,了解 Keras 指标的工作方式非常重要。与为单个批次计算标量值的简单损失函数 (loss function)不同,指标通常需要在同一个训练周期内累积多个批次的信息,才能提供有意义的汇总结果。以准确率为例:按批次计算可能不稳定;你通常希望得到在当前训练周期内所有已处理样本的整体准确率。
为了处理这种情况,Keras 指标被实现为有状态对象。它们维护内部状态变量,这些变量会根据每个批次的数据进行增量更新。在每个训练周期结束时(或评估期间),最终指标值会根据这些累积的状态计算得出。
每个 Keras 指标,无论是内置的还是自定义的,通常都继承自基类 tf.keras.metrics.Metric,并实现了四个主要方法:
__init__(self, name='my_metric', **kwargs):构造函数。在这里,你初始化计算指标所需的状态变量。重要的一点是,状态变量应使用 self.add_weight() 方法创建。这可确保它们被 TensorFlow 正确追踪、在不同执行模式(即时执行与图模式)下管理,并在分布式训练场景中同步。update_state(self, y_true, y_pred, sample_weight=None):此方法处理单个批次的标签 (y_true) 和预测值 (y_pred),并相应地更新内部状态变量。这是累积统计数据的核心逻辑。sample_weight 允许对样本进行可选加权。result(self):此方法使用存储在状态变量中的值来计算并以 tf.Tensor 形式返回最终指标值。它不应修改状态。reset_state(self):此方法将所有状态变量重置回其初始值。Keras 在 model.fit() 期间每个训练周期开始时以及 model.evaluate() 开始时会自动调用此方法。我们通过创建一个自定义指标来说明这一点,该指标用于计算多类别分类问题中特定类别的真阳性数量。这对于关注特定类别的性能可能很有用。
import tensorflow as tf
class CategoricalTruePositives(tf.keras.metrics.Metric):
"""
计算特定目标类别的真阳性数量。
参数:
target_class_id: 整数,要计算真阳性的类别 ID。
name: 字符串,指标实例的名称。
dtype: 指标结果的数据类型。
"""
def __init__(self, target_class_id, name='categorical_true_positives', dtype=tf.int32, **kwargs):
super().__init__(name=name, dtype=dtype, **kwargs)
self.target_class_id = target_class_id
# 使用 add_weight 初始化状态变量
self.true_positives = self.add_weight(
name='tp',
initializer='zeros',
dtype=self.dtype # 使用指标的 dtype
)
def update_state(self, y_true, y_pred, sample_weight=None):
# 确保输入是张量
y_true = tf.cast(y_true, tf.int32)
y_pred = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32) # 获取预测的类别 ID
# 识别目标类别的真阳性
is_target_class = tf.equal(y_true, self.target_class_id)
is_prediction_correct = tf.equal(y_true, y_pred)
# 逻辑与运算以找到目标类别的真阳性
batch_true_positives = tf.logical_and(is_target_class, is_prediction_correct)
batch_true_positives = tf.cast(batch_true_positives, self.dtype)
# 如果提供了样本权重,则进行处理
if sample_weight is not None:
sample_weight = tf.cast(sample_weight, self.dtype)
# 确保权重形状可广播
sample_weight = tf.broadcast_to(sample_weight, tf.shape(batch_true_positives))
batch_true_positives = batch_true_positives * sample_weight
# 更新状态变量
current_sum = tf.reduce_sum(batch_true_positives)
self.true_positives.assign_add(current_sum)
def result(self):
# 返回累积计数
return self.true_positives
def reset_state(self):
# 将状态变量重置为零
self.true_positives.assign(0)
# 可选:用于保存/加载的配置
def get_config(self):
config = super().get_config()
config.update({'target_class_id': self.target_class_id})
return config
在此示例中:
__init__ 存储 target_class_id 并使用 self.add_weight 将单个状态变量 self.true_positives 初始化为零。update_state 接收批次预测和标签,使用 tf.argmax 确定预测类别,检查哪些样本属于 target_class_id 且被正确预测,并相应地增加 self.true_positives 计数,可选地考虑样本权重 (weight)。所有操作都使用 TensorFlow 函数以确保图兼容性。result 简单地返回 self.true_positives 的当前值。reset_state 将 self.true_positives 重置为 0。get_config 的添加是为了让指标(包括其 target_class_id)能够随模型正确保存和加载。将自定义指标集成到你的工作流程中非常直接。你可以实例化它并将其传递给 model.compile() 中的 metrics 列表:
# 假设 model 是一个已定义的多类别分类 Keras 模型
# 假设 num_classes = 10
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=[
'accuracy', # 标准指标
CategoricalTruePositives(target_class_id=3, name='true_positives_class_3'), # 类别 3 的自定义指标
CategoricalTruePositives(target_class_id=7, name='true_positives_class_7') # 类别 7 的另一个实例
]
)
# 现在你可以训练或评估模型
# history = model.fit(train_dataset, epochs=5, validation_data=val_dataset)
# results = model.evaluate(test_dataset)
# 结果字典和 history 对象将包含以下值:
# 'true_positives_class_3' 和 'true_positives_class_7'
Keras 将在训练和评估循环期间自动管理指标的状态更新和重置。
update_state 和 result 中的所有操作都使用 TensorFlow 函数 (tf.*)。如果你打算使用 tf.function(Keras 默认会这样做),请避免使用 AutoGraph 无法转换的 Python 循环或条件逻辑。这些方法中的 NumPy 操作或纯 Python 逻辑很可能在图模式下导致错误或性能问题。tf.cast 以避免类型不匹配错误。使用适当的 dtype 初始化权重 (weight) (add_weight)。sample_weight 或比较预测和标签时。请仔细使用 tf.shape、tf.rank 和广播规则。self.add_weight 对于分布式训练非常重要。TensorFlow 的分布式策略依赖此机制来正确聚合不同设备或工作器上的指标状态。简单的 Python 属性将不会同步。__init__ 中适当初始化状态变量,并确保 reset_state 将它们正确恢复到此初始状态。通过熟练创建自定义指标,你可以精确控制模型性能的衡量方式,从而实现更具洞察力的模型评估和针对特定目标的开发。
这部分内容有帮助吗?
tf.function 如何将 Python 代码转换为可调用的 TensorFlow 图的信息,这对于 Keras 自定义组件的性能和兼容性很重要。add_weight 管理)。© 2026 ApX Machine LearningAI伦理与透明度•