Linear models, such as linear regression or logistic regression, are powerful and easy to interpret, but they inherently assume a linear relationship between the features and the target variable. What happens when that relationship is not a straight line? One way to extend these models to capture non-linear patterns is to create polynomial features. In essence, polynomial features are new features obtained by raising existing numerical features to a power (such as $x^2$, $x^3$) or by multiplying features together (interaction terms, such as $x_1 x_2$). By adding these non-linear terms to the dataset, a linear model can learn curved relationships.

Consider a simple dataset with a single feature $x$. If the true relationship with the target $y$ is quadratic, say $y \approx ax^2 + bx + c$, a standard linear model fitting $y \approx wx + b$ will perform poorly. However, if we create a new feature $x^2$ and fit a linear model on both $x$ and $x^2$, the model becomes $y \approx w_1 x + w_2 x^2 + b$. This is still a linear model in its coefficients ($w_1, w_2, b$), but it can now model a quadratic relationship between the original feature $x$ and the target $y$.

## Generating Polynomial Features with Scikit-learn

Scikit-learn provides a convenient transformer, PolynomialFeatures, in its preprocessing module that generates these features automatically. Let's see it in action. Suppose we have a simple dataset with two features, f1 and f2:

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# Sample data: 3 samples, 2 features
X = np.array([[2, 3], [4, 1], [0, 5]])

# Initialize the PolynomialFeatures transformer for degree 2
# include_bias=False drops the constant term (the column of ones)
poly = PolynomialFeatures(degree=2, include_bias=False)

# Fit and transform the data
X_poly = poly.fit_transform(X)

print("Original features:\n", X)
print("\nPolynomial features (degree=2):\n", X_poly)
print("\nFeature names:", poly.get_feature_names_out(['f1', 'f2']))
```

The output will be:

```
Original features:
 [[2 3]
 [4 1]
 [0 5]]

Polynomial features (degree=2):
 [[ 2.  3.  4.  6.  9.]   # f1, f2, f1^2, f1*f2, f2^2
 [ 4.  1. 16.  4.  1.]
 [ 0.  5.  0.  0. 25.]]

Feature names: ['f1' 'f2' 'f1^2' 'f1 f2' 'f2^2']
```

As you can see, PolynomialFeatures(degree=2) generated the original features (f1, f2), the squared terms (f1^2, f2^2), and the interaction term (f1*f2).

The main parameters of PolynomialFeatures are:

- `degree`: The maximum degree of the generated polynomial features. A degree of 2 generates terms up to $x^2$ and $x_1 x_2$; a degree of 3 generates terms up to $x^3$, $x_1^2 x_2$, $x_1 x_2^2$, and so on.
- `interaction_only`: If set to True, only interaction features (products of distinct features, such as $x_1 x_2$) are generated, without higher powers of individual features (such as $x_1^2$). Defaults to False.
- `include_bias`: If set to True (the default), a bias column (a feature consisting entirely of ones) is included. This is often useful for linear models, but can be redundant when a downstream estimator handles the intercept itself. We set it to False in the example for clarity.

## Visualizing the Impact

Let's visualize how adding polynomial features allows a linear model to fit non-linear data. We will create synthetic data where $y$ is approximately a quadratic function of $x$, plus some noise.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
import plotly.graph_objects as go

# Generate synthetic non-linear data
np.random.seed(42)
n_samples = 100
X = np.random.rand(n_samples, 1) * 10 - 5  # feature values between -5 and 5
y = 0.8 * X**2 + 0.5 * X + 2 + np.random.randn(n_samples, 1) * 4  # quadratic relationship + noise

# 1. Fit a standard linear regression
linear_reg = LinearRegression()
linear_reg.fit(X, y)
y_pred_linear = linear_reg.predict(X)

# 2. Create polynomial features (degree 2) and fit a linear regression
poly_features = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly_features.fit_transform(X)
poly_reg = LinearRegression()
poly_reg.fit(X_poly, y)

# Create predictions on a grid to get smooth lines
X_grid = np.arange(-5, 5, 0.1).reshape(-1, 1)
X_grid_poly = poly_features.transform(X_grid)
y_pred_poly = poly_reg.predict(X_grid_poly)
y_pred_linear_grid = linear_reg.predict(X_grid)  # linear model predictions on the grid

# Build the Plotly figure
fig = go.Figure()

# Scatter plot of the raw data
fig.add_trace(go.Scatter(x=X.flatten(), y=y.flatten(), mode='markers',
                         name='Original data',
                         marker=dict(color='#228be6', opacity=0.7)))

# Line for the standard linear regression fit
fig.add_trace(go.Scatter(x=X_grid.flatten(), y=y_pred_linear_grid.flatten(), mode='lines',
                         name='Linear fit',
                         line=dict(color='#fa5252', width=2)))

# Line for the polynomial regression fit
fig.add_trace(go.Scatter(x=X_grid.flatten(), y=y_pred_poly.flatten(), mode='lines',
                         name='Polynomial fit (degree 2)',
                         line=dict(color='#51cf66', width=2)))

fig.update_layout(
    title="Linear vs. Polynomial Regression Fit",
    xaxis_title="Feature (x)",
    yaxis_title="Target (y)",
    legend_title="Model",
    template="plotly_white",
    width=700, height=400,
    margin=dict(l=20, r=20, t=50, b=20)
)

fig.show()  # displays the chart
```

*[Figure: "Linear vs. Polynomial Regression Fit" — scatter plot of the synthetic data with two overlaid curves: the standard linear fit (red) and the degree-2 polynomial fit (green). Axes: Feature (x) vs. Target (y).]*
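We can also quantify the difference between the two fits rather than only eyeballing the chart. The sketch below regenerates the same synthetic quadratic data (same seed as the example above) and compares the training $R^2$ of the plain linear model against the model with degree-2 features; the exact scores depend on the random noise, so none are asserted here.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

# Regenerate the synthetic quadratic data from the visualization example
np.random.seed(42)
X = np.random.rand(100, 1) * 10 - 5
y = 0.8 * X**2 + 0.5 * X + 2 + np.random.randn(100, 1) * 4

# Plain linear model: only x as input
linear_reg = LinearRegression().fit(X, y)

# Linear model on x and x^2
X_poly = PolynomialFeatures(degree=2, include_bias=False).fit_transform(X)
poly_reg = LinearRegression().fit(X_poly, y)

# R^2 on the training data; the quadratic model should score far higher,
# since the linear model can only capture the small 0.5*x component
print(f"Linear fit R^2:     {linear_reg.score(X, y):.3f}")
print(f"Polynomial fit R^2: {poly_reg.score(X_poly, y):.3f}")
```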
The standard linear fit (red line) fails to capture the curvature in the data. The polynomial fit (green line), which uses the degree-2 features ($x$ and $x^2$), models the underlying quadratic relationship much better.

## Considerations and Best Practices

Although powerful, polynomial features require careful handling:

- **Choosing the degree**: The polynomial degree is a hyperparameter. A degree that is too low may not be flexible enough to capture the underlying pattern (underfitting), while a very high degree can make the model overly complex and fit the noise in the training data (overfitting). The optimal degree is typically determined via cross-validation.
- **Feature explosion**: The number of generated features grows rapidly with both the degree and the number of original features. For $n$ original features and degree $d$, the number of resulting features (including the bias) is given by the binomial coefficient $\binom{n+d}{d} = \frac{(n+d)!}{d!\,n!}$. This can become computationally expensive and increases the risk of overfitting (the "curse of dimensionality").
- **Feature scaling**: It is usually important to scale features (e.g., with StandardScaler or MinMaxScaler) before applying PolynomialFeatures. High-degree polynomial terms can produce very large or very small values, which may cause numerical instability or make the model overly sensitive to features with naturally large ranges. Scaling ensures all features contribute more evenly.
- **Regularization**: When combining polynomial features with a linear model, regularization (such as Ridge, Lasso, or ElasticNet) is almost always recommended. Regularization constrains the model coefficients, preventing them from growing too large and thereby reducing overfitting, especially when many polynomial features are present. Lasso (L1 regularization) can even perform implicit feature selection by shrinking some coefficients exactly to zero.

Here is how you can combine scaling, polynomial feature generation, and a regularized linear model in a Scikit-learn pipeline:

```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import Ridge

# Assumes X and y from the earlier example are available

# Build the pipeline
poly_pipeline = Pipeline([
    ('scaler', StandardScaler()),                                # scale features first
    ('poly', PolynomialFeatures(degree=2, include_bias=False)),  # generate polynomial features
    ('ridge_reg', Ridge(alpha=1.0))                              # regularize with Ridge regression
])

# Fit the pipeline
poly_pipeline.fit(X, y)

# Make predictions (the pipeline handles scaling and transformation)
y_pred_pipeline = poly_pipeline.predict(X)

print("Pipeline fitted successfully.")
print("First 5 predictions:", y_pred_pipeline[:5].flatten())
```

In summary, polynomial features offer a straightforward way to add non-linear capacity to inherently linear models. By generating squared, cubed, and interaction terms, you give these models the ability to learn more complex patterns. This power comes with the responsibility of managing a larger feature space and the potential for overfitting, which typically calls for careful degree selection, feature scaling, and regularization.
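As a final illustration of choosing the degree by cross-validation, the pipeline above can be tuned with GridSearchCV by searching over the `poly__degree` step parameter. This is a sketch on the same synthetic quadratic data; the candidate grids for the degree and the Ridge strength are illustrative choices, not prescribed values.

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# Synthetic quadratic data, as in the visualization example
np.random.seed(42)
X = np.random.rand(100, 1) * 10 - 5
y = (0.8 * X**2 + 0.5 * X + 2 + np.random.randn(100, 1) * 4).ravel()

pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('poly', PolynomialFeatures(include_bias=False)),
    ('ridge', Ridge())
])

# Search over the polynomial degree and the regularization strength;
# step parameters are addressed as '<step_name>__<param_name>'
param_grid = {
    'poly__degree': [1, 2, 3, 4, 5],
    'ridge__alpha': [0.1, 1.0, 10.0],
}
search = GridSearchCV(pipeline, param_grid, cv=5, scoring='r2')
search.fit(X, y)

print("Best parameters:", search.best_params_)
print(f"Best CV R^2: {search.best_score_:.3f}")
```

Because the true relationship here is quadratic, the search should favor a degree of at least 2; a plain linear model (degree 1) cannot explain the dominant $x^2$ component.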