When building custom estimators or transformers that integrate with the Scikit-learn ecosystem, reliability and a clear user experience are important, and a significant part of both comes from thoroughly validating the parameters your component accepts. Just as standard Scikit-learn components check their inputs, your custom classes should do the same. This prevents unexpected runtime errors, guides users toward correct usage, and makes your components more robust and easier to debug.

Parameter validation can happen in two places: in the __init__ method when the object is constructed, or at the start of fit, which is Scikit-learn's own convention since set_params may change parameters after construction. Checks that depend on the data itself always belong in fit or transform. Either way, the goal is to catch invalid inputs as early as possible.
Clear, early error messages such as ValueError: 'solver' must be one of ['svd', 'lsqr', 'eigen'], got 'foo' are far more helpful than a complex traceback originating from matrix decomposition failures. Good validation acts as interactive documentation: users see the valid options immediately, and downstream code can trust the values that reach fit or transform, rather than questioning the inputs.

You should consider several types of checks for your parameters:
- Type checks: is the parameter of an expected type (int, float, str, list, callable, bool, None)? Use isinstance() for this. Be mindful of acceptable variations, like allowing both integers and floats for a numerical parameter via isinstance(param, (int, float)).
- Value ranges: numerical parameters usually have valid ranges (e.g., n_neighbors > 0, alpha >= 0.0, ratio within (0, 1)).
- Allowed options: string parameters typically must belong to a fixed set (e.g., penalty in {'l1', 'l2', 'elasticnet'}).
- Boolean flags: confirm that only True, False, or None are permitted.
- Callables: verify the parameter with callable() and potentially its signature using the inspect module if specific arguments are required.
- Parameter interactions: some combinations are invalid. With method='A', perhaps param_for_A must be provided, while param_for_B must be None. These checks often involve inspecting multiple self attributes within __init__. (The sketch after this list illustrates this and the callable check.)
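The first four checks map directly onto isinstance and membership tests. The last two are easiest to see in code, so here is a small illustrative sketch (all names are hypothetical, not part of any Scikit-learn API) combining a callable-signature check via inspect with an interaction check:

import inspect

def check_scoring_params(score_func, method='A', param_for_a=None, param_for_b=None):
    """Illustrative validation helper; the parameter names are made up."""
    # Callable check: must be callable and accept exactly two arguments
    if not callable(score_func):
        raise TypeError(
            f"'score_func' must be callable, got {type(score_func).__name__}"
        )
    n_args = len(inspect.signature(score_func).parameters)
    if n_args != 2:
        raise ValueError(f"'score_func' must accept 2 arguments, got {n_args}")

    # Interaction check: method='A' requires param_for_a and forbids param_for_b
    if method == 'A':
        if param_for_a is None:
            raise ValueError("param_for_a must be provided when method='A'")
        if param_for_b is not None:
            raise ValueError("param_for_b must be None when method='A'")

check_scoring_params(lambda y_true, y_pred: 0.0, method='A', param_for_a=1)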
Manual Validation in __init__

The most straightforward approach is using conditional logic directly within your __init__ method.

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted


class CustomScaler(BaseEstimator, TransformerMixin):
    """
    A custom scaler that scales features by a constant factor.

    Parameters
    ----------
    scale_factor : float, default=1.0
        The factor to multiply features by. Must be positive.
    strategy : {'multiply', 'divide'}, default='multiply'
        Whether to multiply or divide by the scale_factor.
    """

    def __init__(self, scale_factor=1.0, strategy='multiply'):
        # Manual type and value validation
        if not isinstance(scale_factor, (int, float)):
            raise TypeError(
                f"Parameter 'scale_factor' must be numeric, "
                f"got {type(scale_factor).__name__}"
            )
        if scale_factor <= 0:
            raise ValueError(
                f"Parameter 'scale_factor' must be positive, got {scale_factor}"
            )
        allowed_strategies = {'multiply', 'divide'}
        if strategy not in allowed_strategies:
            raise ValueError(
                f"Parameter 'strategy' must be one of {allowed_strategies}, "
                f"got '{strategy}'"
            )
        self.scale_factor = scale_factor
        self.strategy = strategy

    def fit(self, X, y=None):
        # No fitting necessary for this simple transformer
        # Optional: validate X here with check_array
        # from sklearn.utils.validation import check_array
        # X = check_array(X)
        # The trailing underscore marks a fitted attribute; check_is_fitted
        # looks for exactly this naming pattern to decide if fit was called
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X):
        check_is_fitted(self)
        # Optional: validate X here as well
        # X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"Input has {X.shape[1]} features, expected {self.n_features_in_}"
            )
        if self.strategy == 'multiply':
            return X * self.scale_factor
        elif self.strategy == 'divide':
            # __init__ guarantees scale_factor > 0, so no division by zero
            return X / self.scale_factor
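Constructing the scaler with invalid arguments now fails immediately, with the messages defined in __init__:

try:
    CustomScaler(scale_factor=-1.0)
except ValueError as e:
    print(e)  # Parameter 'scale_factor' must be positive, got -1.0

try:
    CustomScaler(strategy='add')
except ValueError as e:
    print(e)  # e.g. Parameter 'strategy' must be one of {'multiply', 'divide'}, got 'add'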
This is clear for simple cases but can become verbose and repetitive if you have many parameters with complex constraints.
Declarative Validation with _parameter_constraints

Scikit-learn provides a more structured, declarative way to handle parameter validation: a class-level dictionary named _parameter_constraints, which the BaseEstimator._validate_params() method checks, conventionally at the start of fit. Under the hood, _validate_params delegates to the helper function validate_parameter_constraints in sklearn.utils._param_validation. This approach promotes consistency, keeps __init__ free of boilerplate, and matches how Scikit-learn's own estimators work, so that set_params and tools like GridSearchCV can modify parameters freely before the next fit. Note that sklearn.utils._param_validation is a private module (available in recent releases, roughly 1.2 onward), so its exact API may shift between versions.
You define the constraints for each parameter in the _parameter_constraints dictionary. The keys are the parameter names, and the values are lists of allowed types or constraint objects (like Interval, StrOptions).
import numbers  # numbers.Real covers int, float, and NumPy floating types

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted
# Constraint helpers live in a private module; pin or test across versions
from sklearn.utils._param_validation import Interval, StrOptions
class CustomScalerValidated(BaseEstimator, TransformerMixin):
    """
    A custom scaler that scales features by a constant factor.

    Uses the declarative _parameter_constraints mechanism for validation.

    Parameters
    ----------
    scale_factor : float, default=1.0
        The factor to multiply features by. Must be positive.
    strategy : {'multiply', 'divide'}, default='multiply'
        Whether to multiply or divide by the scale_factor.
    """

    # Define constraints using the special class attribute
    _parameter_constraints: dict = {
        "scale_factor": [Interval(numbers.Real, 0, None, closed="neither")],  # real, > 0
        "strategy": [StrOptions({"multiply", "divide"})],  # one of these strings
    }

    def __init__(self, scale_factor=1.0, strategy='multiply'):
        # No explicit validation code needed here: parameters are stored as-is,
        # and _validate_params() checks them against _parameter_constraints in fit
        self.scale_factor = scale_factor
        self.strategy = strategy

    def fit(self, X, y=None):
        self._validate_params()  # inherited from BaseEstimator (scikit-learn >= 1.2)
        X = check_array(X, ensure_2d=True, dtype=np.float64)  # validate X input
        self.n_features_in_ = X.shape[1]
        # Store any fitted attributes here (none needed for this simple scaler)
        # self.mean_ = X.mean(axis=0)  # example if fitting were needed
        return self

    def transform(self, X):
        check_is_fitted(self)  # check that fit has been called
        X = check_array(X, ensure_2d=True, dtype=np.float64)  # validate X input
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"Input has {X.shape[1]} features, but scaler was fitted "
                f"with {self.n_features_in_} features."
            )
        if self.strategy == 'multiply':
            return X * self.scale_factor
        elif self.strategy == 'divide':
            # The constraint guarantees scale_factor > 0, so no division by zero
            return X / self.scale_factor
# Example of how validation works (the constraints are checked when fit runs):
X_demo = np.array([[1.0, 2.0], [3.0, 4.0]])

try:
    CustomScalerValidated(scale_factor=-2.0).fit(X_demo)
except ValueError as e:
    print(f"Validation Error: {e}")
# Prints something like: Validation Error: The 'scale_factor' parameter of
# CustomScalerValidated must be a float in the range (0, inf). Got -2.0 instead.

try:
    CustomScalerValidated(strategy='add').fit(X_demo)
except ValueError as e:
    print(f"Validation Error: {e}")
# Prints something like: Validation Error: The 'strategy' parameter of
# CustomScalerValidated must be a str among {'multiply', 'divide'}. Got 'add' instead.
When fit calls self._validate_params(), Scikit-learn compares the stored parameters against the rules in _parameter_constraints and raises an informative InvalidParameterError (a subclass of both TypeError and ValueError) if they don't match. This is the approach modern Scikit-learn estimators use internally, and it is the recommended one for compatible components. Available constraints include Interval (for numerical ranges), StrOptions (for sets of string options), Options (generic set membership), plain types (like int, float, list, callable, np.ndarray), and None. You can also combine constraints in the list, e.g. [StrOptions({"auto", "manual"}), None] to allow specific strings or None.
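As a sketch of such a combination (the max_bins parameter name is hypothetical, and remember that sklearn.utils._param_validation is a private module), a parameter accepting the string 'auto', a positive integer, or None could be declared like this:

import numbers
from sklearn.utils._param_validation import Interval, StrOptions

# Hypothetical parameter 'max_bins': "auto", a positive int, or None
_parameter_constraints: dict = {
    "max_bins": [
        StrOptions({"auto"}),
        Interval(numbers.Integral, 1, None, closed="left"),
        None,
    ],
}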
Best Practices for Parameter Handling

- Store parameters unmodified: in __init__, always store the parameters passed by the user directly as attributes with the same name (e.g., self.scale_factor = scale_factor). Scikit-learn's tools like get_params and set_params rely on this (see the sketch below).
- Inherit get_params and set_params: ensure your estimator inherits from BaseEstimator (or includes mixins that do). This provides the default get_params and set_params methods essential for hyperparameter tuning (like GridSearchCV) and cloning. These methods work by inspecting the __init__ signature and accessing attributes with corresponding names.
- Validate hyperparameters early: check them manually in __init__, or declare _parameter_constraints and call self._validate_params() at the start of fit, as Scikit-learn itself does.
- Defer data-dependent checks: leave checks that involve X (like shape compatibility) to the fit or transform methods, using Scikit-learn's data validation utilities like check_array and check_X_y.
- Check fitted state: call check_is_fitted at the beginning of transform, predict, etc., to ensure fit has been called.
- Document every parameter: describe each one in the class (or __init__) docstring, clearly stating its type, purpose, allowed range or options, and default value. This complements programmatic validation.

By implementing thorough parameter validation and management, you create custom Scikit-learn components that are functionally correct, user-friendly, and well integrated into the machine learning toolkit. The declarative _parameter_constraints mechanism is the preferred way to achieve this efficiently and consistently with Scikit-learn practice.
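To illustrate the first two points, here is a short check using the CustomScalerValidated class defined earlier:

from sklearn.base import clone

scaler = CustomScalerValidated(scale_factor=2.0, strategy='divide')

# get_params reads attributes named after the __init__ signature
print(scaler.get_params())   # {'scale_factor': 2.0, 'strategy': 'divide'}

# set_params is how GridSearchCV updates hyperparameters between fits
scaler.set_params(scale_factor=3.0)

# clone rebuilds the estimator from get_params(); this would break if
# __init__ renamed or transformed the stored attributes
fresh = clone(scaler)
print(fresh.scale_factor)    # 3.0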