趋近智
itertools 处理复杂序列__getattr__, __getattribute__)multiprocessing 模块concurrent.futures 实现高级并发开发定制估计器和变换器提供了很大的灵活性,但这种能力伴随着确保正确性和兼容性的责任。就像任何软件组件一样,你的定制Scikit-learn对象需要全面测试,以保证它们在机器学习体系中可靠运行。未经测试的组件可能导致不易发现的错误、不正确的结果以及难以调试的管道故障。
严格测试可确认你的组件:
Pipeline、GridSearchCV 和 cross_validate 等工具的兼容性。fit 正确设置 transform 或 predict 使用的属性,且不会不恰当地泄露状态。单元测试和集成测试的组合为你的定制组件提供了全面覆盖。
fit、transform、predict、辅助函数)。你提供受控输入并断言预期输出或内部状态。单元测试速度快,有助于查明代码特定部分中的错误。pytest 等库非常适合在 Python 中编写清晰且易于维护的单元测试。Pipeline 内)是否正常工作。这用于检查兼容性问题并确保数据在步骤之间正确流动。check_estimatorScikit-learn提供了一个强大的工具,专门用于验证定制估计器或变换器是否符合标准API约定:sklearn.utils.estimator_checks.check_estimator。
此函数运行一套全面的预定义测试,涵盖以下方面:
fit、transform、predict、predict_proba 等的正确实现。get_params、set_params)。你可以轻松地将 check_estimator 整合到你的测试套件中,通常使用 pytest。
# Example test file (e.g., test_my_transformer.py)
import pytest
from sklearn.utils.estimator_checks import check_estimator
from my_module import MyCustomTransformer # Your custom class
# Basic pytest integration
def test_check_estimator_compliance():
"""Check if MyCustomTransformer adheres to Scikit-learn conventions."""
check_estimator(MyCustomTransformer())
# You can also instantiate with specific parameters
def test_check_estimator_with_params():
"""Check compliance with non-default parameters."""
estimator = MyCustomTransformer(parameter1='value', parameter2=10)
check_estimator(estimator)
在终端运行 pytest 将执行这些检查。如果任何约定被违反,check_estimator 将引发信息丰富的 AssertionError 异常。尽管争取通过所有检查是理想情况,但某些检查可能不适用于高度专业化的组件。在这种罕见情况下,你可能需要调查具体的失败原因,并在有充分理由时考虑跳过某个检查,并清晰地记录下来。
在进行单元测试时,侧重于:
fit 方法: 验证 fit 是否从训练数据中正确学习了必要参数并将其存储为属性(通常以下划线 _ 结尾)。检查它是否返回 self。transform 方法: 确保 transform 正确使用拟合属性来修改输入数据。验证输出形状和数据类型。测试在拟合后对未见数据调用 transform 是否按预期工作。确保 transform 不会修改在 fit 期间学习的内部状态。fit_transform 方法: Scikit-learn 通常会优化 fit_transform。虽然默认实现只是简单地调用 fit 然后调用 transform,但定制实现可能会提供性能优势。测试你的定制 fit_transform(如果提供)是否与先调用 fit 后调用 transform 产生相同的结果。对于定制估计器,测试内容包括:
fit 方法: 验证 fit 是否从训练数据 X 和标签 y 中正确学习模型参数。检查是否设置了所需的拟合属性(例如,coef_、intercept_、classes_)。确保它返回 self。predict 方法: 测试 predict 是否使用拟合属性为新数据 X 生成预测。验证输出数组的形状和类型。使用多种输入进行测试,包括单个样本和多个样本。predict_proba、decision_function 或 score,请全面测试这些方法,检查输出形状、类型和值范围(例如,概率在 0 到 1 之间)。fit 期间设置的属性得到正确运用,并且不会被预测方法更改。检查 feature_names_in_ 和 n_features_in_ 在 fit 期间是否正确设置。强烈建议使用 pytest 这样的测试运行器。
pytest 夹具非常适合设置可复用的测试数据(例如,样本 X 和 y 数组或 DataFrame),并以受控方式实例化你的定制组件。# Example using pytest fixture
import pytest
import numpy as np
from my_module import MyCustomTransformer
@pytest.fixture
def sample_data():
"""Provides sample data for testing."""
X = np.array([[1, 2], [3, 4], [5, 6]])
return X
def test_transformer_output_shape(sample_data):
"""Check if the transformer maintains the number of samples."""
transformer = MyCustomTransformer()
transformer.fit(sample_data)
X_transformed = transformer.transform(sample_data)
assert X_transformed.shape[0] == sample_data.shape[0]
# Add more specific shape assertions based on the transformer's logic
X.shape = (0, n_features))。X.shape = (1, n_features))。X.shape = (n_samples, 1))。NaN 值或无限值的数据(如果适用)。典型的测试流程涉及多个检查层:
一种针对定制Scikit-learn组件的分层测试方法,从单元测试和API检查开始,随后进行管道内的集成测试。
transform 或 predict 只依赖于在 fit 期间设置的属性。避免在这些方法中修改实例属性。每次调用 fit 都应该根据新数据重置学习到的状态。check_estimator 大量测试通过 sklearn.base.clone 进行的克隆。确保你的 __init__ 方法只接受以名称存储为属性的参数,并且 get_params/set_params 能正确处理这些参数。避免在 __init__ 中使用复杂逻辑或数据验证;在 fit 中执行验证。random_state 参数并正确使用它(例如,为NumPy的随机数生成器设置种子)使测试具有确定性。_) 结尾,不应在 __init__ 中设置。与 __init__ 参数对应的属性不应有尾部下划线。通过采用系统性的测试方法,包括单元测试、运用 check_estimator 和执行集成测试,你将对定制组件建立信心,使你的高级机器学习管道更加可靠和易于维护。
这部分内容有帮助吗?
check_estimator 进行 API 验证和遵守。© 2026 ApX Machine Learning用心打造