Implementing linear regression with Scikit-learn relies on its uniform, straightforward interface. Scikit-learn provides a wide range of machine learning models, including linear regression, that can be applied with minimal effort. The main tool for this task is the `LinearRegression` estimator in the `sklearn.linear_model` module.

## The Scikit-learn Estimator API

Scikit-learn is designed around the idea of an "estimator". An estimator is any object that learns from data; it may be a classification, regression, or clustering algorithm, or a transformer that extracts useful features. All estimators follow a uniform pattern:

1. **Import:** Import the specific estimator class you need.
2. **Instantiate:** Create an instance of the estimator class, optionally configuring its hyperparameters (although `LinearRegression` has few).
3. **Arrange data:** Organize your data into a feature matrix `X` and a target vector `y`. `X` is typically a 2-D NumPy array or Pandas DataFrame (shape: `[n_samples, n_features]`), while `y` is a 1-D NumPy array or Pandas Series (shape: `[n_samples]`).
4. **Fit:** Train the estimator on your data with the `.fit(X, y)` method. This is the step in which the model learns from the data. For linear regression, `fit` computes the optimal coefficients ($\beta_1, ..., \beta_p$) and intercept ($\beta_0$) by minimizing the residual sum of squares between the observed targets in the dataset and the targets predicted by the linear approximation.
5. **Predict:** Once the model is fitted, you can make predictions on new, unseen data with the `.predict(X_new)` method. `X_new` must have the same number of features as the training data `X`.

## Implementing LinearRegression

Let's look at a simple example. We will generate some synthetic data that roughly follows a linear pattern, then fit a `LinearRegression` model to it.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# 1. Generate some sample data
# Make the results reproducible
np.random.seed(42)
X = 2 * np.random.rand(100, 1)  # Features (Scikit-learn expects 2-D)
y = 4 + 3 * X + np.random.randn(100, 1)  # Target variable with noise

# 2. Import the estimator
# Done above: from sklearn.linear_model import LinearRegression

# 3. Instantiate the estimator
model = LinearRegression()

# 4. Fit the model to the data
# Scikit-learn needs X (features) and y (target)
model.fit(X, y)

# 5. Inspect the learned parameters
# The intercept (beta_0) is stored in .intercept_
# The coefficients (beta_1, ..., beta_p) are stored in .coef_
print(f"Intercept (beta_0): {model.intercept_[0]:.4f}")
print(f"Coefficient (beta_1): {model.coef_[0][0]:.4f}")

# 6. Make predictions on new data
# Let's predict the values for X = 0 and X = 2
X_new = np.array([[0], [2]])  # New data points must be a 2-D array
y_pred = model.predict(X_new)
print(f"\nPredictions for X_new = [[0], [2]]:")
print(y_pred)

# Optional: visualize the result
plt.figure(figsize=(8, 5))
plt.scatter(X, y, alpha=0.7, label='Original data')
plt.plot(X_new, y_pred, "r-", linewidth=2, label='Fitted regression line')
plt.xlabel("Feature (X)")
plt.ylabel("Target (y)")
plt.title("Linear Regression Fit")
plt.legend()
plt.grid(True)
plt.show()
```

Running this code will:

1. Create sample data in which `y` is approximately $4 + 3x$.
2. Instantiate a `LinearRegression` model.
3. Call `model.fit(X, y)`. In this step, Scikit-learn applies ordinary least squares, finding the line that best fits the data points by minimizing the sum of squared differences between the actual `y` values and the values predicted by the line ($\hat{y}$).
4. Print the learned intercept (`model.intercept_`) and coefficient (`model.coef_`). You should see values close to the original parameters used to generate the data (4 and 3). The small discrepancies are due to the random noise we added.
5. Compute predicted `y` values for the two new inputs (0 and 2) with `model.predict()`.
6. Produce a plot showing the original data points and the regression line the model learned.
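To see that `.predict()` is nothing more than evaluating the learned linear equation, the predictions can be reproduced by hand from the fitted `intercept_` and `coef_` attributes. A minimal sketch (refitting the same model as in the example above):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Recreate the data and model from the example (same seed)
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)

model = LinearRegression().fit(X, y)

# .predict() evaluates y_hat = intercept_ + X_new @ coef_.T
X_new = np.array([[0.0], [2.0]])
manual = model.intercept_ + X_new @ model.coef_.T
sklearn_pred = model.predict(X_new)

print(np.allclose(manual, sklearn_pred))  # True: both apply the same linear equation
```

This also makes the shape conventions concrete: because `y` was passed as a column vector of shape `(100, 1)`, `coef_` has shape `(1, 1)` and `intercept_` has shape `(1,)`; with a 1-D `y`, `coef_` would instead be 1-D.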
*Figure: scatter plot of the generated data points together with the linear regression line fitted by Scikit-learn.*

## Handling Multiple Features

The process above works in exactly the same way for multiple linear regression, i.e., when you have more than one input feature. The only difference is that your `X` matrix will have more than one column (one per feature). Scikit-learn's `LinearRegression` handles this automatically.

```python
# Example with 2 features
X_multi = 2 * np.random.rand(100, 2)  # X now has 2 columns
# y = 4 + 3*X_1 + 5*X_2 + noise
y_multi = 4 + 3 * X_multi[:, 0] + 5 * X_multi[:, 1] + np.random.randn(100)
y_multi = y_multi.reshape(-1, 1)  # Make y a column vector if needed

multi_model = LinearRegression()
multi_model.fit(X_multi, y_multi)

print(f"\nMultiple regression:")
print(f"Intercept (beta_0): {multi_model.intercept_[0]:.4f}")
print(f"Coefficients (beta_1, beta_2): {multi_model.coef_[0]}")

# Predictions require inputs with 2 features
X_multi_new = np.array([[0, 0], [2, 3]])  # Predict for [X1=0, X2=0] and [X1=2, X2=3]
y_multi_pred = multi_model.predict(X_multi_new)
print(f"\nPredictions for X_multi_new:")
print(y_multi_pred)
```

The `.coef_` attribute now contains an array with multiple values, one coefficient for each feature column in `X_multi`. The interpretation is analogous: each coefficient represents the change in the target `y` for a one-unit change in the corresponding feature, holding all other features constant.

This uniform API makes it simple to apply linear regression whether you have one feature or many. Once the model is fitted, the important next steps are understanding what the learned coefficients mean and evaluating how well the model actually performs, topics we will cover in the following chapters.
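The "holding all other features constant" interpretation can be checked numerically: in a linear model, increasing one feature by exactly one unit while keeping the others fixed changes the prediction by exactly that feature's coefficient. A minimal sketch, refitting a small two-feature model like the one above (the seed here is illustrative, not from the original example):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

np.random.seed(0)  # illustrative seed
X_multi = 2 * np.random.rand(100, 2)
y_multi = 4 + 3 * X_multi[:, 0] + 5 * X_multi[:, 1] + np.random.randn(100)

multi_model = LinearRegression().fit(X_multi, y_multi)

# Two inputs that differ only in the first feature, by exactly one unit
a = np.array([[1.0, 1.0]])
b = np.array([[2.0, 1.0]])  # X1 increased by 1, X2 held constant

delta = multi_model.predict(b) - multi_model.predict(a)
print(delta[0], multi_model.coef_[0])  # delta equals the coefficient of X1
```

This holds exactly for linear models, whatever the values of the other features; for nonlinear models the effect of a one-unit change generally depends on where in feature space you evaluate it.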