Double machine learning (DML) and causal forests are techniques for estimating treatment effects in high-dimensional settings. This hands-on section implements both estimators with common Python libraries so that they can be applied to real datasets. We use simulated data with a known ground truth, which lets us check how well each implementation works. The goal is to estimate the average treatment effect (ATE) with DML and the conditional average treatment effect (CATE) with a causal forest.

## Setup and Data Simulation

First, make sure the required libraries are installed. We mainly use econml, scikit-learn, pandas, and numpy.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.linear_model import LassoCV  # alternative nuisance model to experiment with later
from sklearn.tree import export_text
from econml.dml import LinearDML, CausalForestDML
from econml.cate_interpreter import SingleTreeCateInterpreter
```

Let us simulate data with high-dimensional confounders ($X$), a binary treatment ($T$), and an outcome ($Y$). The simulation is designed so that the treatment effect is heterogeneous and depends on one of the confounders.

```python
# Simulation parameters
n_samples = 5000            # number of samples
n_features = 20             # number of confounders
true_ate = 1.5              # baseline average treatment effect
heterogeneity_slope = 0.5   # slope of the heterogeneity in X0

# Generate confounders
np.random.seed(42)
X = np.random.normal(0, 1, size=(n_samples, n_features))
X_df = pd.DataFrame(X, columns=[f'X{i}' for i in range(n_features)])

# Generate treatment assignment (the propensity score depends on X)
# Simple logistic model for the propensity score
propensity_coeffs = np.random.uniform(-0.5, 0.5, size=n_features)
propensity_logit = X @ propensity_coeffs + np.random.normal(0, 0.1, size=n_samples)
propensity = 1 / (1 + np.exp(-propensity_logit))
T = np.random.binomial(1, propensity, size=n_samples)

# Generate the outcome (depends on X and T, with heterogeneity)
# Y = linear effect of X + treatment effect * T + noise
outcome_coeffs = np.random.uniform(0, 1, size=n_features)
# True CATE = true_ate + heterogeneity_slope * X[:, 0]
true_cate = true_ate + heterogeneity_slope * X[:, 0]
Y = X @ outcome_coeffs + true_cate * T + np.random.normal(0, 0.5, size=n_samples)

print("Simulated data shapes:")
print(f"X: {X.shape}, T: {T.shape}, Y: {Y.shape}")
print(f"True average treatment effect (from the simulation): {np.mean(true_cate):.4f}")
```

This setup mimics a common situation: many features may confound the treatment-outcome relationship, and the effectiveness of the treatment varies with individual characteristics (here, specifically $X_0$).

## Implementing Double Machine Learning (DML) for the ATE

DML estimates the ATE by using machine learning models to partial out the effect of the confounders $X$ from both the treatment $T$ and the outcome $Y$. This involves fitting two nuisance models:

- Outcome model: $E[Y | X]$
- Treatment model: $E[T | X]$ (the propensity score model when $T$ is binary)

The residuals from these two models are then used to estimate the effect. The econml library simplifies this process. We use `LinearDML`, which assumes a linear final stage, to estimate a constant ATE. Note the naming in econml's `fit` method: `X` holds the features that drive effect heterogeneity, while `W` holds controls. Because we want a single constant effect here, the simulated confounders are passed as `W` and `X` is left empty.

```python
# Define the nuisance models
# Outcome model E[Y|X]
model_y = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=42)
# Treatment model E[T|X] (propensity score)
model_t = GradientBoostingClassifier(n_estimators=100, max_depth=3, random_state=42)

# Instantiate the LinearDML estimator
# T is binary, so we set discrete_treatment=True
dml_estimator = LinearDML(model_y=model_y,
                          model_t=model_t,
                          discrete_treatment=True,
                          random_state=123)

# Fit the estimator
# The confounders are passed as controls W; X=None means no effect modifiers,
# so the final stage estimates one constant effect.
dml_estimator.fit(Y, T, X=None, W=X)

# Get the ATE estimate and its confidence interval
ate_estimate = dml_estimator.effect(T0=0, T1=1)  # effect of moving from T=0 to T=1
ate_lb, ate_ub = dml_estimator.effect_interval(T0=0, T1=1, alpha=0.05)  # 95% interval

print(f"DML estimated ATE: {np.ravel(ate_estimate)[0]:.4f}")
print(f"95% confidence interval: [{np.ravel(ate_lb)[0]:.4f}, {np.ravel(ate_ub)[0]:.4f}]")

# Compare with the true ATE computed earlier
print(f"True average treatment effect: {np.mean(true_cate):.4f}")
```

`LinearDML` handles cross-fitting internally, which guards against overfitting bias from the nuisance models, and it provides standard errors for inference. Compare the estimated ATE and its confidence interval with the true average effect from the simulation. They should be reasonably close, which indicates that DML can recover the average effect even under high-dimensional confounding.

## Implementing a Causal Forest for the CATE

DML as used above gives an estimate of the average effect, whereas causal forests aim to reveal heterogeneity in the treatment effect. They adapt the random forest algorithm to estimate the CATE, $E[Y(1) - Y(0) | X=x]$. econml provides `CausalForestDML`, which integrates the DML residualization approach into the forest.

```python
# Instantiate the CausalForestDML estimator
# It residualizes Y and T with the nuisance models (the DML step) and then
# grows a forest on the residuals.
# Nuisance models can be specified explicitly or left at the library defaults.
cf_estimator = CausalForestDML(
    model_y=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=42),
    model_t=GradientBoostingClassifier(n_estimators=100, max_depth=3, random_state=42),
    discrete_treatment=True,
    n_estimators=1000,      # more trees usually give more stable forest estimates
    min_samples_leaf=10,
    max_depth=10,
    random_state=123)

# Fit the causal forest; here X holds the features that may modify the effect
cf_estimator.fit(Y, T, X=X)

# Estimate the CATE for every sample in the dataset
cate_estimates = cf_estimator.effect(X=X)

print(f"Shape of CATE estimates: {cate_estimates.shape}")
print(f"Example CATE estimates (first 5): {cate_estimates[:5].round(4)}")
```

`cf_estimator.effect(X=X)` returns an array of CATE estimates, one per sample, based on that sample's features $X$.

## Visualizing CATE Heterogeneity

To understand how the treatment effect varies, we can plot the estimated CATE against the feature that drives the heterogeneity (in our simulation, $X_0$).

```python
# Scatter plot of estimated CATE against X0
# For readability, plot a subsample of points to reduce overlap
sample_indices = np.random.choice(n_samples, 500, replace=False)
x0_sample = X[sample_indices, 0]
cate_sample = cate_estimates[sample_indices]
x0_sorted = np.sort(x0_sample)

# Define the Plotly figure data and layout
plotly_fig = {
    "data": [
        {
            "type": "scatter",
            "mode": "markers",
            "x": x0_sample.tolist(),    # subsampled data
            "y": cate_sample.tolist(),  # subsampled data
            "marker": {"color": "#228be6", "size": 6, "opacity": 0.6},  # blue
            "name": "Estimated CATE"
        },
        # Line showing the true relationship, for reference
        {
            "type": "scatter",
            "mode": "lines",
            "x": x0_sorted.tolist(),
            "y": (true_ate + heterogeneity_slope * x0_sorted).tolist(),
            "line": {"color": "#f03e3e", "width": 2, "dash": "dash"},  # red
            "name": "True CATE"
        }
    ],
    "layout": {
        "title": "Estimated CATE vs. feature X0",
        "xaxis": {"title": "Feature X0"},
        "yaxis": {"title": "Estimated CATE"},
        "showlegend": True,
        "legend": {"x": 0.01, "y": 0.99},
        "width": 700,
        "height": 450,
        "template": "plotly_white"
    }
}
# To render the figure (requires plotly):
# import plotly.graph_objects as go; go.Figure(plotly_fig).show()
```

[Figure: Estimated CATE vs. feature X0. Scatter of estimated CATE (y-axis) against feature X0 (x-axis), with the true CATE as a dashed red line.]

Estimated conditional average treatment effect (CATE) plotted against feature $X_0$ for a subsample of the data. The dashed red line shows the true CATE relationship defined in the simulation ($CATE = 1.5 + 0.5 \cdot X_0$).

The plot should show the causal forest's CATE estimates following the true upward slope, which indicates that the model captures the tendency of the treatment effect to increase with $X_0$. The scattered points are individual CATE estimates, so some deviation around the true line is expected.

You can interpret the CATE estimates further with tools such as `econml.cate_interpreter.SingleTreeCateInterpreter`, which help identify the features that matter most for the heterogeneity.

```python
# Interpret the CATE model with a simple tree
intrp = SingleTreeCateInterpreter(include_model_uncertainty=False, max_depth=2)
intrp.interpret(cf_estimator, X)

# Plot the interpretation tree (requires matplotlib)
# intrp.plot(feature_names=list(X_df.columns), fontsize=12)
# If the model works well, the plot should mainly show splits on X0.

# Print the surrogate tree as text with scikit-learn's export_text.
# Assumes the fitted surrogate tree is exposed as the attribute tree_model_.
print("CATE interpretation tree structure:")
print(export_text(intrp.tree_model_, feature_names=list(X_df.columns)))
```

This summary gives a simplified tree structure that approximates the causal forest's CATE function and usually highlights the most influential features ($X_0$ in our case).

## Summary and Next Steps

This hands-on section showed how to use econml in a high-dimensional setting to implement double machine learning for ATE estimation and a causal forest for CATE estimation. We saw how DML handles confounders through nuisance models and thereby recovers the average effect, and how the causal forest reveals heterogeneity in the treatment effect.

Key takeaways:

- DML relies on accurate modeling of the nuisance functions ($E[Y|X]$ and $E[T|X]$), so the choice of machine learning models for these tasks matters.
- Causal forests build on DML principles to estimate how the effect varies with individual characteristics $X$.
- Simulated data lets us validate these methods by comparing the estimates with the known ground truth.

Next, you can try different nuisance models in DML and the causal forest (for example, `LassoCV` versus gradient boosting), tune the hyperparameters, and apply these techniques to your own datasets. Remember to consider the underlying assumptions carefully, in particular unconfoundedness given the observed covariates $X$ (conditional ignorability). The next chapter examines what happens when unobserved confounders cause this assumption to be violated.
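The residual-on-residual idea behind DML, described earlier in this section, can also be sketched without any causal-inference library. The block below is a minimal illustration under stated assumptions, not a description of what econml does internally: it uses plain least squares as a stand-in for the machine learning nuisance models, a constant treatment effect of 1.5, and a fresh small simulation. The names `cross_fit_residuals`, `partialling_out_ate`, and the `*_sim` variables are hypothetical helpers introduced only for this example.

```python
import numpy as np


def cross_fit_residuals(X, target, n_folds=5, seed=0):
    """Out-of-fold residuals target - E_hat[target | X].

    Assumption: ordinary least squares stands in for the ML nuisance learner.
    """
    n = X.shape[0]
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n), n_folds)
    X1 = np.column_stack([np.ones(n), X])  # add an intercept column
    residuals = np.empty(n)
    for held_out in folds:
        train = np.setdiff1d(np.arange(n), held_out)
        # Fit on the training folds only, predict on the held-out fold
        coef, *_ = np.linalg.lstsq(X1[train], target[train], rcond=None)
        residuals[held_out] = target[held_out] - X1[held_out] @ coef
    return residuals


def partialling_out_ate(Y, T, X, n_folds=5, seed=0):
    """Regress outcome residuals on treatment residuals (constant effect)."""
    y_res = cross_fit_residuals(X, Y, n_folds, seed)
    t_res = cross_fit_residuals(X, T.astype(float), n_folds, seed)
    theta = np.sum(t_res * y_res) / np.sum(t_res ** 2)
    # Heteroskedasticity-robust standard error of the final-stage slope
    eps = y_res - theta * t_res
    se = np.sqrt(np.sum((t_res * eps) ** 2)) / np.sum(t_res ** 2)
    return theta, se


# Small simulation with a constant true effect of 1.5 (an assumption of this sketch)
rng = np.random.default_rng(42)
n, p = 5000, 20
X_sim = rng.normal(size=(n, p))
propensity_sim = 1 / (1 + np.exp(-(X_sim @ rng.uniform(-0.5, 0.5, size=p))))
T_sim = rng.binomial(1, propensity_sim)
Y_sim = X_sim @ rng.uniform(0, 1, size=p) + 1.5 * T_sim + rng.normal(0, 0.5, size=n)

theta_hat, se_hat = partialling_out_ate(Y_sim, T_sim, X_sim)
# Naive difference in means, which ignores confounding by X
naive = Y_sim[T_sim == 1].mean() - Y_sim[T_sim == 0].mean()

print(f"Naive difference in means: {naive:.4f}")
print(f"Partialling-out estimate:  {theta_hat:.4f} (SE {se_hat:.4f})")
```

Because the simulated outcome is linear in the confounders, least squares is adequate here. With nonlinear confounding, the nuisance step would need a more flexible learner, which is the role the gradient boosting models play in the econml examples.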