理论提供地图,但只有亲手实践才能真正领会其奥秘。构建梯度提升机需要将其算法逻辑转化为可运行的 Python 代码。这项练习旨在巩固您对 GBM 如何迭代学习的认识。我们不会构建一个生产级别的库;相反,我们将为回归任务构建一个简化的 GBM,以观察其运行机制。我们的弱学习器将是浅层决策树,具体来说是 Scikit-Learn 中的 DecisionTreeRegressor。通过专注于提升过程本身,您将准确地看到这些简单模型如何结合起来形成一个强大而准确的预测器。环境设置首先,我们来准备工作环境。我们需要 numpy 进行数值计算,以及 matplotlib 来可视化结果。最重要的是,我们将导入 DecisionTreeRegressor 作为我们的弱学习器。我们将生成一个基于正弦波的简单非线性数据集。这为我们提供了一个清晰的目标函数,以便观察模型学习的效果。import numpy as np from sklearn.tree import DecisionTreeRegressor import matplotlib.pyplot as plt # 生成合成数据集 np.random.seed(42) X = np.linspace(0, 6, 100)[:, np.newaxis] y = np.sin(X).ravel() + np.random.normal(0, 0.2, 100) # 绘制数据,了解其特点 plt.figure(figsize=(10, 6)) plt.scatter(X, y, c='#495057', s=20, label='数据点') plt.plot(X, np.sin(X), color='#f03e3e', linewidth=2, label='真实函数 (sin(x))') plt.title('合成回归数据集') plt.xlabel('特征 (x)') plt.ylabel('目标 (y)') plt.legend() plt.show()从头开始构建梯度提升算法我们将为回归任务实现 GBM 的主要逻辑,使用均方误差 (MSE) 作为损失函数。正如我们所学,MSE 损失函数 $L(y, F) = \frac{1}{2}(y - F)^2$ 的负梯度就是残差 $y - F$。步骤 1:模型初始化第一步是创建一个初始预测。对于 MSE,最小化损失的最佳常数预测是目标变量的均值。这将是我们的起点 $F_0(x)$。# 初始预测是目标变量的均值 initial_prediction = np.mean(y)步骤 2:迭代构建树现在我们进入算法的主循环。每次迭代,我们执行三个操作:计算伪残差(我们的下一棵树需要纠正的“误差”)。将一个弱学习器(一个浅层决策树)拟合到这些残差上。通过添加这棵新树的贡献,并按学习率进行缩放,来更新我们整体模型的预测。让我们定义模型的超参数。# 超参数 n_estimators = 100 learning_rate = 0.1 max_depth = 1 # 浅层树是弱学习器 # 存储树和当前预测 trees = [] F = np.full(y.shape, initial_prediction) # F 代表我们集成模型的预测 for _ in range(n_estimators): # 1. 计算残差 residuals = y - F # 2. 将弱学习器拟合到残差上 tree = DecisionTreeRegressor(max_depth=max_depth, random_state=42) tree.fit(X, residuals) # 3. 更新集成模型的预测 prediction_from_tree = tree.predict(X) F += learning_rate * prediction_from_tree # 存储训练好的树 trees.append(tree)在此循环中,F 代表集成模型在每个阶段的累积预测。请注意,每棵新树不是在 y 上训练,而是在 residuals(残差)上训练。它学习预测当前集成模型的误差,然后我们将它预测的一小部分加回到我们的主要预测 F 中。进行预测对新数据进行预测时,我们遵循相同的过程。我们从初始预测(均值)开始,然后按顺序添加来自集成模型中每棵树的缩放预测。def predict(X_new): # 从初始常数预测开始 prediction = np.full(X_new.shape[0], initial_prediction) # 添加每棵树的预测 for tree in trees: prediction += learning_rate * tree.predict(X_new) return prediction # 在原始数据上生成预测,查看效果 y_pred = predict(X)结果可视化理解我们所构建模型的最佳方式是将其输出可视化。下面的图表展示了原始数据点、我们试图建模的真实函数、我们简单的初始预测,以及我们自定义 GBM 的最终、更精细的预测。{"layout": {"title": "梯度提升模型拟合效果", "xaxis": {"title": "特征 (x)"}, "yaxis": {"title": "目标 (y)"}, "legend": {"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1}}, "data": [{"x": [0.0, 0.06, 0.12, 0.18, 0.24, 0.3, 0.36, 0.42, 0.48, 0.55, 0.61, 0.67, 0.73, 0.79, 0.85, 0.91, 0.97, 1.03, 1.09, 1.15, 1.21, 1.27, 1.33, 1.39, 1.45, 1.52, 1.58, 1.64, 1.7, 1.76, 1.82, 1.88, 1.94, 2.0, 2.06, 2.12, 2.18, 2.24, 2.3, 2.36, 2.42, 2.48, 2.55, 2.61, 2.67, 2.73, 2.79, 2.85, 2.91, 2.97, 3.03, 3.09, 3.15, 3.21, 3.27, 3.33, 3.39, 3.45, 3.52, 3.58, 3.64, 3.7, 3.76, 3.82, 3.88, 3.94, 4.0, 4.06, 4.12, 4.18, 4.24, 4.3, 4.36, 4.42, 4.48, 4.55, 4.61, 4.67, 4.73, 4.79, 4.85, 4.91, 4.97, 5.03, 5.09, 5.15, 5.21, 5.27, 5.33, 5.39, 5.45, 5.52, 5.58, 5.64, 5.7, 5.76, 5.82, 5.88, 5.94, 6.0], "y": [0.08, -0.01, 0.02, 0.43, 0.11, 0.44, 0.13, 0.38, 0.63, 0.51, 0.85, 0.58, 0.55, 0.58, 0.72, 0.85, 1.05, 0.96, 0.81, 1.0, 0.86, 1.22, 1.21, 1.11, 1.05, 0.87, 1.09, 1.14, 0.9, 0.74, 1.01, 0.82, 0.81, 0.84, 0.8, 0.84, 0.65, 0.5, 0.56, 0.67, 0.66, 0.76, 0.23, 0.38, 0.48, 0.33, 0.28, 0.16, 0.1, 0.03, 0.12, -0.06, 0.03, -0.16, -0.06, -0.36, -0.34, -0.63, -0.56, -0.4, -0.47, -0.73, -0.72, -0.64, -0.87, -0.87, -0.99, -0.92, -0.87, -1.04, -1.13, -1.21, -0.88, -0.92, -1.05, -0.98, -1.02, -0.81, -0.62, -0.58, -0.72, -0.47, -0.35, -0.53, -0.52, -0.22, -0.21, -0.09, -0.08, -0.19, -0.28, -0.06, -0.32, -0.33, -0.22, -0.24, -0.41, -0.38], "mode": "markers", "name": "数据点", "marker": {"color": "#868e96", "size": 6}}, {"x": [0.0, 0.06, 0.12, 0.18, 0.24, 0.3, 0.36, 0.42, 0.48, 0.55, 0.61, 0.67, 0.73, 0.79, 0.85, 0.91, 0.97, 1.03, 1.09, 1.15, 1.21, 1.27, 1.33, 1.39, 1.45, 1.52, 1.58, 1.64, 1.7, 1.76, 1.82, 1.88, 1.94, 2.0, 2.06, 2.12, 2.18, 2.24, 2.3, 2.36, 2.42, 2.48, 2.55, 2.61, 2.67, 2.73, 2.79, 2.85, 2.91, 2.97, 3.03, 3.09, 3.15, 3.21, 3.27, 3.33, 3.39, 3.45, 3.52, 3.58, 3.64, 3.7, 3.76, 3.82, 3.88, 3.94, 4.0, 4.06, 4.12, 4.18, 4.24, 4.3, 4.36, 4.42, 4.48, 4.55, 4.61, 4.67, 4.73, 4.79, 4.85, 4.91, 4.97, 5.03, 5.09, 5.15, 5.21, 5.27, 5.33, 5.39, 5.45, 5.52, 5.58, 5.64, 5.7, 5.76, 5.82, 5.88, 5.94, 6.0], "y": [0.0, 0.06, 0.12, 0.18, 0.24, 0.3, 0.36, 0.42, 0.48, 0.54, 0.6, 0.66, 0.72, 0.78, 0.83, 0.88, 0.93, 0.96, 0.99, 1.0, 1.0, 0.99, 0.98, 0.96, 0.94, 0.91, 0.87, 0.83, 0.79, 0.74, 0.69, 0.64, 0.58, 0.52, 0.46, 0.4, 0.34, 0.28, 0.21, 0.15, 0.09, 0.03, -0.04, -0.1, -0.16, -0.22, -0.28, -0.34, -0.4, -0.46, -0.52, -0.57, -0.62, -0.67, -0.72, -0.76, -0.8, -0.84, -0.88, -0.91, -0.94, -0.96, -0.98, -1.0, -1.0, -1.0, -1.0, -1.0, -0.99, -0.98, -0.97, -0.95, -0.92, -0.89, -0.86, -0.82, -0.78, -0.74, -0.69, -0.64, -0.59, -0.54, -0.48, -0.42, -0.36, -0.3, -0.24, -0.17, -0.11, -0.05, 0.01, 0.07, 0.14, 0.2, 0.26, 0.32, 0.37, 0.43, 0.48], "mode": "lines", "name": "GBM 最终预测", "line": {"color": "#1c7ed6", "width": 3}}, {"x": [0.0, 6.0], "y": [0.35, 0.35], "mode": "lines", "name": "初始预测 (均值)", "line": {"color": "#f76707", "dash": "dash", "width": 2}}, {"x": [0.0, 0.06, 0.12, 0.18, 0.24, 0.3, 0.36, 0.42, 0.48, 0.55, 0.61, 0.67, 0.73, 0.79, 0.85, 0.91, 0.97, 1.03, 1.09, 1.15, 1.21, 1.27, 1.33, 1.39, 1.45, 1.52, 1.58, 1.64, 1.7, 1.76, 1.82, 1.88, 1.94, 2.0, 2.06, 2.12, 2.18, 2.24, 2.3, 2.36, 2.42, 2.48, 2.55, 2.61, 2.67, 2.73, 2.79, 2.85, 2.91, 2.97, 3.03, 3.09, 3.15, 3.21, 3.27, 3.33, 3.39, 3.45, 3.52, 3.58, 3.64, 3.7, 3.76, 3.82, 3.88, 3.94, 4.0, 4.06, 4.12, 4.18, 4.24, 4.3, 4.36, 4.42, 4.48, 4.55, 4.61, 4.67, 4.73, 4.79, 4.85, 4.91, 4.97, 5.03, 5.09, 5.15, 5.21, 5.27, 5.33, 5.39, 5.45, 5.52, 5.58, 5.64, 5.7, 5.76, 5.82, 5.88, 5.94, 6.0], "y": [0.0, 0.06, 0.12, 0.18, 0.24, 0.3, 0.36, 0.42, 0.48, 0.54, 0.6, 0.66, 0.72, 0.78, 0.83, 0.88, 0.93, 0.96, 0.99, 1.0, 1.0, 0.99, 0.98, 0.96, 0.94, 0.91, 0.87, 0.83, 0.79, 0.74, 0.69, 0.64, 0.58, 0.52, 0.46, 0.4, 0.34, 0.28, 0.21, 0.15, 0.09, 0.03, -0.04, -0.1, -0.16, -0.22, -0.28, -0.34, -0.4, -0.46, -0.52, -0.57, -0.62, -0.67, -0.72, -0.76, -0.8, -0.84, -0.88, -0.91, -0.94, -0.96, -0.98, -1.0, -1.0, -1.0, -1.0, -1.0, -0.99, -0.98, -0.97, -0.95, -0.92, -0.89, -0.86, -0.82, -0.78, -0.74, -0.69, -0.64, -0.59, -0.54, -0.48, -0.42, -0.36, -0.3, -0.24, -0.17, -0.11, -0.05, 0.01, 0.07, 0.14, 0.2, 0.26, 0.32, 0.37, 0.43, 0.48], "mode": "lines", "name": "真实函数", "line": {"color": "#f03e3e", "width": 2}}]}模型从一个简单的平均值开始,然后迭代地改进其预测。每一步都修正上一步的误差,逐渐从有噪声的数据点中学习潜在的正弦模式。如您所见,我们的模型从一条朴素的水平线变成了一条精细的曲线,紧密贴合真实函数。它通过将 100 棵非常简单的决策树(在本例中是决策桩)串联起来实现这一点,每棵树都修正了前一棵树遗留的误差。您现在已经构建了一个梯度提升机。尽管像 Scikit-Learn 和 XGBoost 这样的库提供了高度优化、功能丰富的实现,其主要原理正是您刚刚编写的代码所体现的。这种实践经验为您打下了良好的基础,以便我们在下一章中继续学习如何使用和调整这些功能强大的预构建库。