趋近智
虽然使用函数的程序化脚本对于简单任务有效,但构建更复杂且更易于维护的机器学习系统常受益于面向对象编程(OOP)提供的结构。OOP 有助于将代码组织成逻辑单元,使项目随着增长更易于管理、复用和扩展。基本的 OOP 原则在机器学习流程中高度适用。
核心在于,OOP 围绕着 类 和 对象。
DatasetLoader 类,它指定数据集应如何加载,包括文件路径等属性以及 load_csv() 或 get_features() 等方法。DatasetLoader 对象,每个指向不同的文件路径,但都共享类中定义的相同加载逻辑。import pandas as pd
class SimpleDatasetLoader:
"""一个从CSV文件加载数据的基本类。"""
def __init__(self, filepath):
"""使用文件路径初始化加载器。"""
self.filepath = filepath
self.data = None # 用于存储已加载数据的属性
print(f"Loader initialized for: {self.filepath}")
def load_data(self):
"""将数据从CSV文件加载到Pandas DataFrame中。"""
try:
self.data = pd.read_csv(self.filepath)
print(f"Data loaded successfully with {self.data.shape[1]} columns.")
except FileNotFoundError:
print(f"Error: File not found at {self.filepath}")
self.data = None
except Exception as e:
print(f"An error occurred during loading: {e}")
self.data = None
def get_shape(self):
"""返回已加载数据的形状(如果可用)。"""
if self.data is not None:
return self.data.shape
else:
return "No data loaded."
# 创建类的对象(实例)
loader1 = SimpleDatasetLoader('data/train.csv')
loader2 = SimpleDatasetLoader('data/test.csv')
# 使用对象的方法
loader1.load_data()
print(f"Shape of dataset 1: {loader1.get_shape()}")
loader2.load_data()
print(f"Shape of dataset 2: {loader2.get_shape()}")
在此示例中,SimpleDatasetLoader 是类(蓝图)。loader1 和 loader2 是对象(实例),各自拥有 filepath 属性,但共享类中定义的 load_data 和 get_shape 方法。__init__ 方法是一个特殊的构造方法,在对象创建时运行。
封装意味着将数据(属性)和操作这些数据的方法捆绑到一个单元(类)中。它还涉及控制对对象内部状态的访问,通常称为数据隐藏。
在 Python 中,封装更多是基于约定而非严格强制。像单个下划线 (_) 这样的前缀表示属性或方法 intended for internal use (内部使用),而双下划线 (__) 则触发名称修饰(name mangling),使其难以(但并非不可能)从类外部直接访问。
封装在机器学习中的作用包括:
考虑一个用于特征缩放的类:
import numpy as np
class SimpleStandardScaler:
"""一个基本的标准化缩放器实现。"""
def __init__(self):
self._mean = None # 内部状态:均值
self._std_dev = None # 内部状态:标准差
def fit(self, X):
"""计算均值和标准差。"""
# 输入 X 预期为 NumPy 数组
if not isinstance(X, np.ndarray):
X = np.array(X)
self._mean = np.mean(X, axis=0)
self._std_dev = np.std(X, axis=0)
# 处理标准差为零的情况(常数特征)
self._std_dev[self._std_dev == 0] = 1.0
print("Scaler fitted.")
def transform(self, X):
"""使用计算出的均值和标准差应用缩放。"""
if self._mean is None or self._std_dev is None:
raise ValueError("Scaler has not been fitted yet.")
if not isinstance(X, np.ndarray):
X = np.array(X)
# 广播按列元素级应用缩放
return (X - self._mean) / self._std_dev
def fit_transform(self, X):
"""拟合缩放器,然后转换数据。"""
self.fit(X)
return self.transform(X)
# 使用示例
data = np.array([[1, 10], [2, 12], [3, 11], [4, 15]])
scaler = SimpleStandardScaler()
# 拟合缩放器到数据
scaler.fit(data)
# 转换新数据(或原始数据)
scaled_data = scaler.transform(data)
print("Scaled Data:\n", scaled_data)
# 访问内部状态(可能,但按惯例不鼓励)
# print(scaler._mean)
在此处,_mean 和 _std_dev 是由 fit 方法管理并由 transform 方法使用的内部状态。用户主要通过 fit、transform 和 fit_transform 进行交互。
继承允许新类(派生类或子类)从现有类(基类或父类)继承属性和方法。这促进代码复用并建立“是一种”(is-a)关系(例如,DecisionTreeModel 是 BaseModel 的一种)。
在机器学习中,继承被频繁使用:
BaseEstimator、TransformerMixin)。为了创建与库生态系统(如 Pipelines 或 GridSearch)兼容的自定义模型或转换器,你需要继承这些基类并实现所需的方法(fit、predict、transform)。Model 类,并创建 LinearRegressionModel 或 NeuralNetworkModel 等特化类,它们继承通用功能(如保存/加载),但以不同方式实现训练和预测。数据转换器的继承结构示例。
# 假设 BaseTransformer 在其他地方定义(例如在 scikit-learn 中)
# 为了说明,我们定义一个基类:
class BaseTransformer:
def fit(self, X, y=None):
# 默认实现:不执行任何操作
return self
def transform(self, X):
# 基类通常会引发 NotImplementedError
# 以强制子类实现基本方法
raise NotImplementedError("Subclasses must implement transform()")
def fit_transform(self, X, y=None):
self.fit(X, y)
return self.transform(X)
class SimpleMinMaxScaler(BaseTransformer): # 继承自 BaseTransformer
"""将特征缩放到 [0, 1] 范围。"""
def __init__(self):
self._min = None
self._range = None
def fit(self, X, y=None):
if not isinstance(X, np.ndarray):
X = np.array(X)
self._min = np.min(X, axis=0)
self._range = np.max(X, axis=0) - self._min
# 处理零范围(常数特征)
self._range[self._range == 0] = 1.0
print("MinMaxScaler fitted.")
return self # 对于链式调用/管道很重要
def transform(self, X):
if self._min is None or self._range is None:
raise ValueError("MinMaxScaler has not been fitted yet.")
if not isinstance(X, np.ndarray):
X = np.array(X)
return (X - self._min) / self._range
# 使用示例
min_max_scaler = SimpleMinMaxScaler()
data = np.array([[1, 10], [2, 12], [3, 11], [4, 15]])
scaled_data_minmax = min_max_scaler.fit_transform(data)
print("MinMax Scaled Data:\n", scaled_data_minmax)
在此处,SimpleMinMaxScaler 继承自 BaseTransformer。它提供了它自己的 fit 和 transform 特定实现,同时可能从基类中定义的方法(如此处所示的 fit_transform)中受益。
多态(“多种形式”)允许不同类的对象以各自特定的方式响应相同的方法调用。如果多个类继承自同一个基类并实现了一个方法(如 transform),你可以在这些派生类的任何对象上调用该方法,并且会执行正确的实现。
这是机器学习管道工作方式的根本。管道可能包含各种转换步骤(不同缩放器或编码器类的对象)。当你调用 pipeline.fit(data) 或 pipeline.transform(data) 时,管道会遍历其步骤,对每个对象调用 fit 或 transform 方法。多态确保在每个步骤应用适当的缩放、编码或插补逻辑,即使具体类不同。
# 使用之前定义的缩放器类
scaler_std = SimpleStandardScaler()
scaler_minmax = SimpleMinMaxScaler()
transformers = [scaler_std, scaler_minmax]
data_to_process = np.array([[50, 5], [60, 7], [70, 6]])
# 通过相同的接口使用不同的转换器处理数据
for i, transformer in enumerate(transformers):
print(f"\n--- Processing with Transformer {i+1} ({transformer.__class__.__name__}) ---")
# 使用通用接口进行拟合和转换
processed_data = transformer.fit_transform(data_to_process)
print("Processed Data:\n", processed_data)
在此循环中,scaler_std 和 scaler_minmax 对象都被视为 transformer。调用 transformer.fit_transform() 在第一次迭代中执行 SimpleStandardScaler 类中定义的该方法的特定版本,并在第二次迭代中执行 SimpleMinMaxScaler 类中的版本。
在开发机器学习系统时应用 OOP 原则具有多项优势:
尽管并非所有机器学习脚本都需要完全面向对象,但理解这些原则有助于你编写更具扩展性和可维护性的代码,特别是当你构建更精巧的模型和数据处理管道时。它也为理解和扩展许多流行的机器学习库提供了基础。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造