Double Machine Learning (DML) and causal forests are techniques for estimating treatment effects in high-dimensional settings. In this section we implement both estimators with common Python libraries and verify them on simulated data, where the ground truth is known. Our goal is to estimate the Average Treatment Effect (ATE) using DML and the Conditional Average Treatment Effects (CATE) using a causal forest.

## Setup and Data Simulation

First, ensure you have the necessary libraries installed. We'll primarily use econml, scikit-learn, pandas, and numpy.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt  # used to display the CATE interpreter tree

from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.linear_model import LassoCV  # optional alternative nuisance model

from econml.dml import LinearDML, CausalForestDML
from econml.cate_interpreter import SingleTreeCateInterpreter
```

Let's simulate data with high-dimensional confounders ($X$), a binary treatment ($T$), and an outcome ($Y$). We design the simulation so that the treatment effect is heterogeneous, depending on one of the confounders.

```python
# Simulation parameters
n_samples = 5000           # number of samples
n_features = 20            # number of confounders
true_ate = 1.5             # base average treatment effect
heterogeneity_slope = 0.5  # slope for heterogeneity based on X0

# Generate confounders
np.random.seed(42)
X = np.random.normal(0, 1, size=(n_samples, n_features))
X_df = pd.DataFrame(X, columns=[f'X{i}' for i in range(n_features)])

# Generate treatment assignment (propensity score depends on X)
# Simple logistic model for the propensity
propensity_coeffs = np.random.uniform(-0.5, 0.5, size=n_features)
propensity_logit = X @ propensity_coeffs + np.random.normal(0, 0.1, size=n_samples)
propensity = 1 / (1 + np.exp(-propensity_logit))
T = np.random.binomial(1, propensity, size=n_samples)

# Generate outcome (depends on X and T, with a heterogeneous treatment effect)
# Y = linear effect of X + treatment effect * T + noise
outcome_coeffs = np.random.uniform(0, 1, size=n_features)
# True CATE = true_ate + heterogeneity_slope * X[:, 0]
true_cate = true_ate + heterogeneity_slope * X[:, 0]
Y = X @ outcome_coeffs + true_cate * T + np.random.normal(0, 0.5, size=n_samples)

print("Simulated data shapes:")
print(f"X: {X.shape}, T: {T.shape}, Y: {Y.shape}")
print(f"True Average Treatment Effect (based on simulation): {np.mean(true_cate):.4f}")
```

This setup mimics a common scenario: many features potentially confound the treatment-outcome relationship, and the treatment's effectiveness varies across individuals based on their characteristics (here, specifically $X_0$).

## Implementing Double Machine Learning (DML) for ATE

DML estimates the ATE by using machine learning models to partial out the effect of the confounders $X$ from both the treatment $T$ and the outcome $Y$. This involves fitting two nuisance models:

- Outcome model: $E[Y | X]$
- Treatment model: $E[T | X]$ (a propensity score model when $T$ is binary)

We then estimate the effect using the residuals from these models, as sketched below. The econml library simplifies this process.
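To make the residualization idea concrete, here is a minimal manual sketch of the partialling-out step, assuming the `X`, `T`, and `Y` arrays simulated above. It uses cross-fitted (out-of-fold) nuisance predictions and a no-intercept regression of outcome residuals on treatment residuals; unlike the packaged `LinearDML` used next, it omits the standard-error machinery.

```python
from sklearn.model_selection import cross_val_predict

# Cross-fitted nuisance predictions: each sample is predicted by models
# trained on the other folds, which avoids overfitting bias.
y_hat = cross_val_predict(
    GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=0),
    X, Y, cv=5)
t_hat = cross_val_predict(
    GradientBoostingClassifier(n_estimators=100, max_depth=3, random_state=0),
    X, T, cv=5, method="predict_proba")[:, 1]

# Residualize the outcome and the treatment
y_res = Y - y_hat
t_res = T - t_hat

# Final stage: no-intercept OLS of outcome residuals on treatment residuals
theta_manual = np.sum(t_res * y_res) / np.sum(t_res ** 2)
print(f"Manual DML ATE estimate: {theta_manual:.4f}")
```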
We'll use LinearDML, which assumes a linear final stage, to estimate a constant ATE.

```python
# Define nuisance models
# Outcome model E[Y|X]
model_y = GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=42)
# Treatment model E[T|X] (propensity)
model_t = GradientBoostingClassifier(n_estimators=100, max_depth=3, random_state=42)

# Instantiate the LinearDML estimator
# We use discrete_treatment=True since T is binary
dml_estimator = LinearDML(model_y=model_y,
                          model_t=model_t,
                          discrete_treatment=True,
                          random_state=123)

# Fit the estimator. In econml, W holds controls/confounders and X holds
# effect modifiers; since we only want a constant ATE here, we pass the
# confounders as W and leave X empty.
dml_estimator.fit(Y, T, W=X)

# Get the ATE estimate and confidence interval (effect of moving T from 0 to 1)
ate_estimate = dml_estimator.ate()
ate_ci = dml_estimator.ate_interval(alpha=0.05)  # 95% CI

print(f"DML Estimated ATE: {ate_estimate:.4f}")
print(f"95% Confidence Interval: [{ate_ci[0]:.4f}, {ate_ci[1]:.4f}]")

# Compare with the true ATE calculated earlier
print(f"True Average Treatment Effect: {np.mean(true_cate):.4f}")
```

LinearDML handles the cross-fitting procedure internally to prevent overfitting and provides standard errors for inference. Compare the estimated ATE and its confidence interval to the true average effect from our simulation; they should be reasonably close, demonstrating DML's ability to recover the average effect despite high-dimensional confounding.

## Implementing Causal Forests for CATE

While DML provides an estimate of the average effect, causal forests aim to reveal heterogeneity in treatment effects. They adapt the random forest algorithm to estimate the CATE, $E[Y(1) - Y(0) | X=x]$. econml provides CausalForestDML, which integrates the DML residualization approach within the forest structure.

```python
# Instantiate the CausalForestDML estimator
# It uses DML-style orthogonalization within the forest splits.
# We can specify nuisance models or use the defaults.
cf_estimator = CausalForestDML(
    model_y=GradientBoostingRegressor(n_estimators=100, max_depth=3, random_state=42),
    model_t=GradientBoostingClassifier(n_estimators=100, max_depth=3, random_state=42),
    discrete_treatment=True,
    n_estimators=1000,     # more trees are generally better for forests
    min_samples_leaf=10,
    max_depth=10,
    random_state=123)

# Fit the causal forest; here X holds the candidate effect modifiers
cf_estimator.fit(Y, T, X=X)

# Estimate the CATE for every sample in the dataset
cate_estimates = cf_estimator.effect(X)

print(f"Shape of CATE estimates: {cate_estimates.shape}")
print(f"Example CATE estimates (first 5): {cate_estimates[:5].round(4)}")
```

`cf_estimator.effect(X)` returns an array of CATE estimates, one for each sample based on its features $X$.
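Because the ground truth is simulated, we can go beyond eyeballing the estimates and quantify how close the forest gets. The following quick check, which assumes the `true_cate` array from the simulation above, reports the root-mean-squared error and the correlation with the truth:

```python
# Compare estimated CATEs against the simulated ground truth
rmse = np.sqrt(np.mean((cate_estimates - true_cate) ** 2))
corr = np.corrcoef(cate_estimates, true_cate)[0, 1]
print(f"CATE RMSE vs. ground truth: {rmse:.4f}")
print(f"Correlation with true CATE: {corr:.4f}")
```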
## Visualizing CATE Heterogeneity

To understand how the treatment effect varies, we can plot the estimated CATE against the feature driving the heterogeneity ($X_0$ in our simulation).

```python
import plotly.graph_objects as go

# Scatter plot of estimated CATE vs. X0
# Sample 500 points for visualization to avoid overplotting
sample_indices = np.random.choice(n_samples, 500, replace=False)
x0_sample = X[sample_indices, 0]
cate_sample = cate_estimates[sample_indices]
x0_sorted = np.sort(x0_sample)

# Define the Plotly chart data and layout
plotly_fig = {
    "data": [
        {
            "type": "scatter",
            "mode": "markers",
            "x": x0_sample.tolist(),
            "y": cate_sample.tolist(),
            "marker": {"color": "#228be6", "size": 6, "opacity": 0.6},  # blue
            "name": "Estimated CATE"
        },
        # Line showing the true relationship, for reference
        {
            "type": "scatter",
            "mode": "lines",
            "x": x0_sorted.tolist(),
            "y": (true_ate + heterogeneity_slope * x0_sorted).tolist(),
            "line": {"color": "#f03e3e", "width": 2, "dash": "dash"},  # red
            "name": "True CATE"
        }
    ],
    "layout": {
        "title": "Estimated CATE vs. Feature X0",
        "xaxis": {"title": "Feature X0"},
        "yaxis": {"title": "Estimated CATE"},
        "showlegend": True,
        "legend": {"x": 0.01, "y": 0.99},
        "width": 700,
        "height": 450,
        "template": "plotly_white"
    }
}

# Render the figure
go.Figure(plotly_fig).show()
```
*Estimated Conditional Average Treatment Effects (CATE) plotted against the values of feature X0 for a sample of the data. The dashed red line indicates the true CATE relationship defined in the simulation ($CATE = 1.5 + 0.5 \cdot X_0$).*

The plot should show that the causal forest's CATE estimates generally follow the true positive slope, indicating that the model captured the increasing treatment effect at higher values of $X_0$. The scatter shows the individual CATE estimates, which naturally have some variance around the true line.

You can further interpret the CATE estimates using tools like econml.cate_interpreter.SingleTreeCateInterpreter to understand which features drive the heterogeneity.

```python
# Approximate the CATE function with a shallow, interpretable tree
intrp = SingleTreeCateInterpreter(include_model_uncertainty=False, max_depth=2)
intrp.interpret(cf_estimator, X)

# Plot the interpretation tree (rendered with matplotlib)
# If the model works well, the splits should fall primarily on X0.
intrp.plot(feature_names=list(X_df.columns), fontsize=12)
plt.show()
```

The plotted tree is a simplified approximation of the causal forest's CATE function, often highlighting the most influential features (like $X_0$ in our case).
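CausalForestDML also supports pointwise inference for the CATE. As a quick illustration (a sketch assuming the estimator's default forest-based inference is enabled, as it is after the `fit` call above), we can request 95% intervals for a few samples:

```python
# Pointwise 95% confidence intervals for the first five CATE estimates
lb, ub = cf_estimator.effect_interval(X[:5], alpha=0.05)
for i in range(5):
    print(f"CATE[{i}]: {cate_estimates[i]:.3f}  95% CI: [{lb[i]:.3f}, {ub[i]:.3f}]")
```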
## Summary and Next Steps

This practice session demonstrated Double Machine Learning for ATE estimation and causal forests for CATE estimation in a high-dimensional setting using econml. We saw how DML recovers the average effect by partialling out confounders with nuisance models, and how causal forests reveal heterogeneity in treatment effects.

Important takeaways:

- DML relies on accurate modeling of the nuisance functions $E[Y|X]$ and $E[T|X]$; the choice of ML models for these tasks matters.
- Causal forests build on DML principles to estimate how effects vary with individual characteristics $X$.
- Simulated data lets us verify the methods by comparing estimates against a known ground truth.

From here, you can experiment with different nuisance models within DML and causal forests (e.g., LassoCV vs. GradientBoosting), tune hyperparameters, and apply these techniques to your own datasets. Remember to carefully consider the assumptions underlying these methods, particularly unconfoundedness (conditional ignorability) given the observed covariates $X$. The next chapter explores situations where this assumption is violated by unobserved confounding.