For regression tasks, where the goal is to predict a continuous value, Scikit-Learn provides the `GradientBoostingRegressor` class. This class implements the gradient boosting machine algorithm: it builds an additive model by fitting decision trees sequentially, where each new tree is trained to correct the errors made by the combination of all previous trees. `GradientBoostingRegressor` is a powerful and flexible tool suited to a wide range of regression problems, from estimating house prices to forecasting demand. Its effectiveness comes from its ability to model complex, non-linear relationships in data.

## The GradientBoostingRegressor Class

First, you need to import the class from `sklearn.ensemble`. Instantiating it is straightforward and will feel familiar if you have used other Scikit-Learn models.

```python
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np

# Create some synthetic data
X = np.random.rand(100, 1) * 10
y = np.sin(X).ravel() + np.random.normal(0, 0.3, 100)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Instantiate the model with default parameters
gbr = GradientBoostingRegressor(random_state=42)

# Fit the model on the training data
gbr.fit(X_train, y_train)

# Make predictions on the test set
y_pred = gbr.predict(X_test)

# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
print(f"Mean squared error: {mse:.4f}")
```

This snippet demonstrates the standard workflow: instantiate, fit, and predict. The default parameters usually provide a reasonable starting point, but understanding the main parameters is important for building high-performing models.

## Configuring the Regressor

The behavior of `GradientBoostingRegressor` is controlled by several important parameters. Let's look at the ones you will adjust most often.

### Loss Function (loss)

The `loss` parameter defines the loss function to optimize. The right choice depends on the specifics of your regression problem, particularly its sensitivity to outliers.

- `'squared_error'` (named `'ls'` in scikit-learn releases before 1.0): the default option, for least-squares regression. It minimizes the L2 loss, which is equivalent to the mean squared error (MSE). A good general-purpose choice, but it can be sensitive to outliers.
- `'absolute_error'` (formerly `'lad'`): least absolute deviations, which minimizes the L1 loss, equivalent to the mean absolute error (MAE). It is more robust to outliers than least squares.
- `'huber'`: a combination of least squares and least absolute deviations. It behaves like least squares for small errors and like least absolute deviations for large ones, balancing sensitivity and robustness.
- `'quantile'`: enables quantile regression. Instead of predicting the mean, this loss predicts a specific quantile of the target, selected by the `alpha` parameter (for example, the 50th percentile, i.e. the median); see the sketch below.
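To make the quantile option concrete, here is a minimal sketch (not part of the original example; the synthetic data and the 0.1/0.9 quantile choices are illustrative assumptions) that fits one regressor per quantile and checks how often the targets fall between the two predictions:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Illustrative synthetic data (assumed for this sketch)
rng = np.random.RandomState(42)
X = rng.rand(200, 1) * 10
y = np.sin(X).ravel() + rng.normal(0, 0.3, 200)

# One model per quantile; alpha selects the quantile to predict
gbr_low = GradientBoostingRegressor(loss="quantile", alpha=0.1, random_state=42)
gbr_high = GradientBoostingRegressor(loss="quantile", alpha=0.9, random_state=42)
gbr_low.fit(X, y)
gbr_high.fit(X, y)

# Roughly 80% of the targets should fall between the two predictions
lower = gbr_low.predict(X)
upper = gbr_high.predict(X)
coverage = np.mean((y >= lower) & (y <= upper))
print(f"Empirical coverage of the 10%-90% interval: {coverage:.2f}")
```

Fitting one model per quantile like this is a common way to build rough prediction intervals with gradient boosting.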
### Model Complexity and Learning

The interplay between `n_estimators`, `learning_rate`, and `max_depth` controls the model's ability to fit the training data without overfitting (see the sketch after this list).

- `n_estimators`: sets the number of boosting stages, which corresponds to the number of trees in the ensemble. More trees can capture more complex patterns, but too many can lead to overfitting.
- `learning_rate`: often called shrinkage, it scales the contribution of each tree. A smaller learning rate (e.g., 0.01) requires a larger `n_estimators` to reach the same training error, but typically yields better generalization. It effectively slows the learning process, preventing the model from making drastic corrections with each new tree.
- `max_depth`: controls the maximum depth of the individual decision trees. Shallow trees (e.g., `max_depth=3`) are constrained and act as weak learners, which is central to the boosting process. Deeper trees can model more complex feature interactions but increase the risk of overfitting the training data.
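One way to observe the trade-off between `learning_rate` and `n_estimators` is to track the test error after each boosting stage with the estimator's `staged_predict` method and note where it bottoms out. A minimal sketch, reusing synthetic data in the style of the first example (the specific parameter values are illustrative assumptions):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Illustrative synthetic data, mirroring the first example
rng = np.random.RandomState(42)
X = rng.rand(200, 1) * 10
y = np.sin(X).ravel() + rng.normal(0, 0.3, 200)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# A small learning rate paired with many estimators
gbr = GradientBoostingRegressor(n_estimators=500, learning_rate=0.05,
                                max_depth=3, random_state=42)
gbr.fit(X_train, y_train)

# staged_predict yields the ensemble's prediction after each boosting stage
test_errors = [mean_squared_error(y_test, y_pred)
               for y_pred in gbr.staged_predict(X_test)]
best_stage = int(np.argmin(test_errors)) + 1
print(f"Lowest test MSE {min(test_errors):.4f} at stage {best_stage}")
```

If the test error keeps decreasing through the final stage, the model may benefit from more estimators; if it rises well before the end, fewer trees or a smaller learning rate may generalize better.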
## A Practical Example

Let's build a model to fit a more complex non-linear function and visualize its predictions. We will use a lightly configured model to observe the effect of changing the parameters.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Generate a noisy non-linear dataset
np.random.seed(0)
X = np.linspace(0, 6, 150)[:, np.newaxis]
y = (X * np.sin(X)).ravel() + np.random.normal(0, 0.5, 150)

# Instantiate and configure the model
gbr_tuned = GradientBoostingRegressor(
    n_estimators=200,      # more trees
    learning_rate=0.05,    # smaller learning rate
    max_depth=4,           # slightly deeper trees
    loss='squared_error',  # standard least-squares loss ('ls' in older versions)
    random_state=42
)

# Fit the model
gbr_tuned.fit(X, y)

# Create a smooth line for visualizing the predictions
X_plot = np.linspace(0, 6, 500)[:, np.newaxis]
y_plot = gbr_tuned.predict(X_plot)
```

By setting a smaller `learning_rate` and a larger `n_estimators`, we push the model to learn the underlying pattern more gradually. A `max_depth` of 4 lets each tree capture a moderate level of feature interaction. The visualization below shows how an ensemble of simple trees effectively approximates the complex sinusoidal function.

[Figure: "Gradient boosting regressor fit". X-axis: Feature (X); y-axis: Target (y). Blue markers show the noisy training data; a red line shows the GBM predictions.]

The model's predictions (red line) closely follow the underlying pattern of the noisy training data (blue points), demonstrating its ability to learn complex non-linear relationships. With a regressor built, the next steps involve understanding why it makes the predictions it does and how to handle classification problems. In the upcoming chapters, we will look at methods for interpreting these models and introduce their classification counterpart, `GradientBoostingClassifier`.