趋近智
itertools 处理复杂序列__getattr__, __getattribute__)multiprocessing 模块concurrent.futures 实现高级并发构建与Scikit-learn生态系统结合的自定义评估器或转换器,确保其可靠性并提供清晰的用户体验具有很大的意义。这方面一个主要的方面是彻底验证组件接受的参数。正如标准Scikit-learn组件会检查其输入一样,你的自定义类也应如此。这有助于防止意外的运行时错误,引导用户正确使用,并使你的组件更可靠、更易于调试。
参数验证主要在自定义类的__init__方法中进行,尽管某些检查可能会推迟到fit或transform等方法中,当参数与数据的交互变得相关时。目标是尽可能早地发现无效输入。
ValueError: 'solver' must be one of ['svd', 'lsqr', 'eigen'], got 'foo',比源自矩阵分解失败的复杂追溯信息更有帮助。良好的验证可以起到交互式文档的作用。fit或transform的核心逻辑上,而不是质疑输入。你的参数应该考虑以下几种检查类型:
int、float、str、list、callable、bool、None)?为此使用isinstance()。请注意可接受的变体,例如使用isinstance(param, (int, float))允许数值参数为整数或浮点数。n_neighbors > 0,alpha >= 0.0,ratio between (0, 1))。penalty in {'l1', 'l2', 'elasticnet'})。True、False或None。callable()检查其是否存在,如果需要特定参数,则可能需要使用inspect模块检查其签名。method='A',可能必须提供param_for_A,而param_for_B必须是None。这些检查通常涉及在__init__中检查多个self属性。最直接的方法是在__init__方法中直接使用条件逻辑。
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
class CustomScaler(BaseEstimator, TransformerMixin):
"""
一个自定义的缩放器,按常数因子缩放特征。
参数
----------
scale_factor : float, default=1.0
用于特征乘法的因子。必须为正数。
strategy : {'multiply', 'divide'}, default='multiply'
是乘以还是除以 scale_factor。
"""
def __init__(self, scale_factor=1.0, strategy='multiply'):
# 手动类型和值验证
if not isinstance(scale_factor, (int, float)):
raise TypeError(f"Parameter 'scale_factor' must be numeric, got {type(scale_factor).__name__}")
if scale_factor <= 0:
raise ValueError(f"Parameter 'scale_factor' must be positive, got {scale_factor}")
allowed_strategies = {'multiply', 'divide'}
if strategy not in allowed_strategies:
raise ValueError(f"Parameter 'strategy' must be one of {allowed_strategies}, got '{strategy}'")
self.scale_factor = scale_factor
self.strategy = strategy
def fit(self, X, y=None):
# 这个简单的转换器不需要拟合
# 可选:在此处使用 check_array 添加 X 的输入验证
# from sklearn.utils.validation import check_array
# X = check_array(X)
self._n_features_in = X.shape[1]
return self
def transform(self, X):
check_is_fitted(self)
# 可选:在此处添加 X 的输入验证
# X = check_array(X)
if X.shape[1] != self._n_features_in:
raise ValueError(f"Input has {X.shape[1]} features, expected {self._n_features_in}")
if self.strategy == 'multiply':
return X * self.scale_factor
elif self.strategy == 'divide':
# 如果 scale_factor 可能为零,则添加除零检查(已通过 init 验证阻止)
return X / self.scale_factor
这对于简单的情况来说很清晰,但如果参数很多且约束复杂,可能会变得冗长且重复。
Scikit-learn 提供了一种更结构化、声明性的方式来处理参数验证,即使用validate_parameter_constraints装饰器和一个名为_parameter_constraints的类级别字典。这种方法提升了一致性并减少了__init__中的样板代码。
你可以在_parameter_constraints字典中定义每个参数的约束。键是参数名称,值是允许的类型列表或约束对象(如Interval、StrOptions)。
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted, check_array
# 导入约束类和装饰器
from sklearn.utils._param_validation import validate_parameter_constraints, Interval, StrOptions
import numbers # 使用 numbers.Real 以包含更广泛的数值类型
class CustomScalerValidated(BaseEstimator, TransformerMixin):
"""
一个自定义的缩放器,按常数因子缩放特征。
使用 validate_parameter_constraints 进行验证。
参数
----------
scale_factor : float, default=1.0
用于特征乘法的因子。必须为正数。
strategy : {'multiply', 'divide'}, default='multiply'
是乘以还是除以 scale_factor。
"""
# 使用特殊类属性定义约束
_parameter_constraints: dict = {
"scale_factor": [Interval(numbers.Real, 0, None, closed="neither")], # 必须是实数,且 > 0
"strategy": [StrOptions({"multiply", "divide"})], # 必须是这些字符串之一
}
# 将装饰器应用于类
@validate_parameter_constraints
def __init__(self, scale_factor=1.0, strategy='multiply'):
# 这里不需要显式的验证代码!
# 装饰器会根据 _parameter_constraints 处理验证。
self.scale_factor = scale_factor
self.strategy = strategy
def fit(self, X, y=None):
X = check_array(X, ensure_2d=True, dtype=np.float64) # 在 fit 中验证 X 输入
self._n_features_in = X.shape[1]
# 在此处存储任何拟合属性(此简单缩放器不需要)
# self.mean_ = X.mean(axis=0) # 如果需要拟合的例子
return self
def transform(self, X):
check_is_fitted(self) # 检查是否已调用 fit
X = check_array(X, ensure_2d=True, dtype=np.float64) # 在 transform 中验证 X 输入
if X.shape[1] != self._n_features_in:
raise ValueError(f"Input has {X.shape[1]} features, but scaler was fitted with {self._n_features_in} features.")
if self.strategy == 'multiply':
return X * self.scale_factor
elif self.strategy == 'divide':
# 约束确保 scale_factor > 0,因此不会出现除零错误
return X / self.scale_factor
# 验证如何工作的例子:
try:
scaler = CustomScalerValidated(scale_factor=-2.0)
except ValueError as e:
print(f"Validation Error: {e}")
# 预期输出: Validation Error: CustomScalerValidated 的 'scale_factor' 参数必须是严格的正实数。得到的是 -2.0。
try:
scaler = CustomScalerValidated(strategy='add')
except ValueError as e:
print(f"Validation Error: {e}")
# 预期输出: Validation Error: CustomScalerValidated 的 'strategy' 参数必须是 {'multiply', 'divide'} 中的一个字符串。得到的是 'add'。
validate_parameter_constraints装饰器会检查传递给__init__的参数,将它们与_parameter_constraints中定义的规则进行比较,如果它们不匹配,则引发信息性的TypeError或ValueError异常。这是现代Scikit-learn兼容组件的推荐方法。可用的约束包括Interval(用于数值范围)、StrOptions(用于字符串选项)、Options(通用集合成员)、类型提示(如int、float、list、callable、np.ndarray)以及None。你还可以在列表中组合约束(例如,[StrOptions({"auto", "manual"}), None]以允许特定的字符串或None)。
__init__中,始终将用户传递的参数直接存储为同名属性(例如,self.scale_factor = scale_factor)。Scikit-learn的get_params和set_params等工具依赖于此。get_params 和 set_params: 确保你的评估器继承自BaseEstimator(或包含相应的混合类)。这提供了默认的get_params和set_params方法,对于超参数调整(如GridSearchCV)和克隆非常重要。这些方法通过检查__init__签名并访问具有相应名称的属性来工作。__init__中执行初始类型和基本值检查(理想情况下使用validate_parameter_constraints)。X的检查(如形状兼容性)推迟到fit或transform方法。在此处使用Scikit-learn的数据验证工具,如check_array和check_X_y。transform、predict等方法的开头使用check_is_fitted,以确保fit已被调用。__init__文档字符串中记录每个参数,清晰说明其类型、用途、允许的范围或选项以及默认值。这补充了程序化验证。通过实现全面的参数验证与管理,你可以创建功能正确、用户友好并与机器学习工具包良好集成的自定义Scikit-learn组件。使用validate_parameter_constraints是根据Scikit-learn最佳实践高效且一致地实现这一点的首选方法。
这部分内容有帮助吗?
_parameter_constraints和validate_parameter_constraints进行参数验证,以及通用API约定。check_array、check_X_y和check_is_fitted等工具的官方文档,这些工具对于在自定义组件中验证输入数据和估计器状态至关重要。© 2026 ApX Machine Learning用心打造