While Double Machine Learning and Causal Forests leverage machine learning components for estimating nuisance functions (like propensity scores or conditional outcome expectations), deep learning provides a different avenue: designing end-to-end neural network architectures specifically for causal effect estimation. These approaches are particularly compelling when dealing with very high-dimensional data (e.g., images, text) or when intricate, non-linear relationships are suspected among confounders, treatment, and outcome.
The motivation stems from the powerful representation learning capabilities of deep neural networks. Instead of relying on pre-specified feature interactions or basis expansions, deep learning models can potentially learn complex data representations ϕ(X) from high-dimensional covariates X that are optimized for the causal task at hand.
Neural Network Architectures for Causal Estimation
Several architectures adapt standard neural networks for estimating potential outcomes and treatment effects. Two foundational examples are TARNet and DragonNet.
TARNet
TARNet (Treatment-Agnostic Representation Network) employs a simple but effective architecture. It uses a shared set of bottom layers to learn a representation ϕ(X) directly from the covariates X. This shared representation then feeds into two separate "heads": shallow networks that estimate the expected potential outcomes under treatment (E[Y(1)∣X=x]) and control (E[Y(0)∣X=x]), respectively.
Structure of TARNet. Input covariates X pass through shared layers learning a representation ϕ(X). This representation is used by separate heads to predict potential outcomes under control (h0) and treatment (h1).
The model is trained end-to-end by minimizing the empirical risk based on the observed factual outcomes:

Lfactual = (1/n) ∑i ℓ(Yi, hTi(ϕ(Xi)))
where ℓ is a suitable loss function (e.g., mean squared error for continuous outcomes) and hTi denotes the outcome head matching unit i's observed treatment Ti. The idea is that forcing the network to learn a shared representation ϕ(X) that is predictive of the outcome Y in both treatment groups encourages it to learn features relevant for outcome prediction, implicitly helping to balance the covariate distributions between the treated and control groups within the representation space. The CATE can then be estimated as:
CATE(x)=h1(ϕ(x))−h0(ϕ(x))
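To make the architecture concrete, below is a minimal PyTorch sketch of TARNet. The layer widths and depths, and the helper names factual_loss and cate, are illustrative assumptions rather than a reference implementation:

```python
import torch
import torch.nn as nn

class TARNet(nn.Module):
    """Shared representation layers feeding two outcome heads."""
    def __init__(self, n_covariates: int, hidden_dim: int = 64):
        super().__init__()
        # Shared layers learn the representation phi(X).
        self.phi = nn.Sequential(
            nn.Linear(n_covariates, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # Separate shallow heads estimate E[Y(0)|X=x] and E[Y(1)|X=x].
        self.h0 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                nn.Linear(hidden_dim, 1))
        self.h1 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                nn.Linear(hidden_dim, 1))

    def forward(self, x):
        rep = self.phi(x)
        return self.h0(rep), self.h1(rep)

def factual_loss(model, x, t, y):
    """L_factual: each unit contributes only through the head matching
    its observed treatment (t and y have shape (n, 1))."""
    y0_hat, y1_hat = model(x)
    y_hat = torch.where(t.bool(), y1_hat, y0_hat)  # pick the factual head
    return nn.functional.mse_loss(y_hat, y)

def cate(model, x):
    """CATE(x) = h1(phi(x)) - h0(phi(x))."""
    y0_hat, y1_hat = model(x)
    return (y1_hat - y0_hat).squeeze(-1)
```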
DragonNet
DragonNet extends TARNet by explicitly incorporating a propensity score estimation head. It recognizes that the representation ϕ(X) needs to be good not only for predicting outcomes but also for adjusting for confounding bias. The architecture includes three components stemming from the shared representation layers:
Outcome Head for Control: h0(ϕ(X)) estimates E[Y(0)∣X=x].
Outcome Head for Treatment: h1(ϕ(X)) estimates E[Y(1)∣X=x].
Propensity Head: e(ϕ(X)) estimates the propensity score P(T=1∣X=x).
Structure of DragonNet. Similar to TARNet but adds a third head to predict the propensity score e(ϕ(X)), alongside the outcome heads h0 and h1.
The training objective for DragonNet typically combines the factual outcome loss (as in TARNet) with a loss for the propensity score prediction (e.g., binary cross-entropy):
LDragonNet = Lfactual + α·Lpropensity
where Lpropensity = (1/n) ∑i ℓBCE(Ti, e(ϕ(Xi))) and α is a hyperparameter balancing the two tasks. The intuition is that forcing the representation ϕ(X) to also be predictive of the treatment assignment T acts as a "targeted regularization", ensuring that the representation captures confounding information effectively. This can lead to better-calibrated CATE estimates: CATE(x) = h1(ϕ(x)) − h0(ϕ(x)). Some variations also add terms to the loss function that explicitly encourage the distribution of ϕ(X) to be similar for treated and control groups, using techniques like Maximum Mean Discrepancy (MMD).
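A sketch of the DragonNet extension, reusing the TARNet class from the earlier sketch. The single linear layer for the propensity head, the default α, and the linear_mmd helper (the simplest, linear-kernel form of the optional balancing penalty) are all illustrative assumptions:

```python
class DragonNet(TARNet):
    """TARNet plus a propensity head e(phi(X))."""
    def __init__(self, n_covariates: int, hidden_dim: int = 64):
        super().__init__(n_covariates, hidden_dim)
        self.e = nn.Linear(hidden_dim, 1)  # logits for P(T=1 | X=x)

    def forward(self, x):
        rep = self.phi(x)
        return self.h0(rep), self.h1(rep), self.e(rep), rep

def dragonnet_loss(model, x, t, y, alpha: float = 1.0):
    """L_DragonNet = L_factual + alpha * L_propensity."""
    y0_hat, y1_hat, t_logit, _ = model(x)
    y_hat = torch.where(t.bool(), y1_hat, y0_hat)
    l_factual = nn.functional.mse_loss(y_hat, y)
    # Binary cross-entropy of the predicted propensity against observed T.
    l_propensity = nn.functional.binary_cross_entropy_with_logits(t_logit, t)
    return l_factual + alpha * l_propensity

def linear_mmd(rep, t):
    """Linear-kernel MMD: squared distance between the mean representations
    of the treated and control groups; added to the loss in some variants."""
    mask = t.squeeze(-1).bool()
    return ((rep[mask].mean(dim=0) - rep[~mask].mean(dim=0)) ** 2).sum()
```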
Other Deep Learning Approaches
The flexibility of neural networks has led to other architectures:
Causal Effect Variational Autoencoders (CEVAE): These models use variational autoencoders to model latent confounders, treating the observed covariates as noisy proxies for them, and attempt to infer causal effects even when confounding is not directly observed (under specific structural assumptions).
Generative Adversarial Networks (GANs): Architectures like GANITE use GANs to learn to generate counterfactual outcomes directly, framing causal inference as a missing data problem where the GAN fills in the unobserved potential outcome.
These models often address more complex scenarios or make different assumptions, representing the ongoing research in applying deep learning to causality.
Implementation Approaches
Successfully applying these deep learning models requires careful attention to several aspects:
Objective Functions: As seen, the loss function often balances multiple objectives: predicting factual outcomes, predicting propensity scores, and sometimes explicitly matching representation distributions. The relative weighting of these terms is a significant hyperparameter.
Regularization: Standard deep learning regularization techniques like dropout, weight decay (L2 regularization), and early stopping are essential to prevent overfitting, particularly given the complexity of the models and the potential to overfit the treatment assignment mechanism; see the training-loop sketch after this list.
Hyperparameter Tuning: Neural network architectures (number of layers, nodes per layer, activation functions), learning rates, batch sizes, and regularization strengths must be tuned carefully. Validation requires specific strategies suitable for CATE models, as discussed later in Section 3.6 ("Validation and Calibration of CATE Estimators"). Standard cross-validation can be misleading if not adapted for causal estimation.
Software: Implementations often rely on standard deep learning frameworks like TensorFlow or PyTorch, sometimes using specialized libraries built on top that provide wrappers for these causal architectures.
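A minimal training loop tying these points together for the DragonNet sketch above: weight decay is set on the optimizer, and early stopping monitors the loss on a held-out split. The split fraction, patience, and other hyperparameters are placeholder values:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def train(model, x, t, y, epochs=200, patience=10, lr=1e-3, weight_decay=1e-4):
    n_val = len(x) // 5  # hold out 20% of units for early stopping
    loader = DataLoader(TensorDataset(x[n_val:], t[n_val:], y[n_val:]),
                        batch_size=128, shuffle=True)
    x_val, t_val, y_val = x[:n_val], t[:n_val], y[:n_val]
    # Weight decay (L2 regularization) is configured on the optimizer.
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_val, stale = float("inf"), 0
    for _ in range(epochs):
        model.train()
        for xb, tb, yb in loader:
            opt.zero_grad()
            dragonnet_loss(model, xb, tb, yb, alpha=1.0).backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            val = dragonnet_loss(model, x_val, t_val, y_val).item()
        # Early stopping: halt once the validation loss stops improving.
        # (A fuller version would also checkpoint the best weights.)
        if val < best_val:
            best_val, stale = val, 0
        else:
            stale += 1
            if stale >= patience:
                break
    return model
```

Note that the early-stopping criterion here is the combined factual/propensity loss on the held-out split, which, per the validation point above, is only a proxy for CATE quality.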
Advantages and Disadvantages
Deep learning approaches offer distinct benefits and drawbacks:
Advantages:
Flexibility: Can model highly complex, non-linear relationships without manual feature engineering.
Representation Learning: Effective at handling unstructured or very high-dimensional data (images, text) by learning relevant features.
End-to-End Training: Potential to optimize all components (representation, outcome prediction, propensity estimation) jointly.
Disadvantages:
Data Requirements: Typically require large datasets to train effectively and avoid overfitting.
Computational Cost: Training can be computationally intensive compared to DML or Causal Forests.
Interpretability: Neural networks are often treated as "black boxes," making it harder to understand how they arrive at a specific CATE estimate.
Sensitivity: Performance can be highly sensitive to network architecture, hyperparameters, and initialization. Careful tuning and validation are necessary.
Risk of Bias Amplification: Without proper regularization or architectural design (like DragonNet's propensity head), they might inadvertently rely on variables predictive of treatment but not outcome, potentially amplifying bias.
Deep learning methods represent a powerful set of tools for causal effect estimation in high dimensions, especially when dealing with complex data structures and non-linearities where methods like DML or Causal Forests might be less suitable. However, their successful application demands significant expertise in both deep learning and causal inference principles, along with careful implementation and validation.
Learning Representations for Counterfactual Inference, Fredrik D. Johansson, Uri Shalit, David Sontag, 2016. Proceedings of The 33rd International Conference on Machine Learning (ICML), Vol. 48 (JMLR.org). Foundational work on learning representations for counterfactual inference and a direct precursor of the TARNet architecture.
Causal Effect Inference with Deep Latent Variable Models, Christos Louizos, Uri Shalit, Joris M. Mooij, David Sontag, Richard Zemel, Max Welling, 2017. Advances in Neural Information Processing Systems 30 (NeurIPS). DOI: 10.48550/arXiv.1705.08821. Presents the Causal Effect Variational Autoencoder (CEVAE) for causal inference with latent confounders.