While Double Machine Learning and Causal Forests leverage machine learning components for estimating nuisance functions (like propensity scores or conditional outcome expectations), deep learning provides a different avenue: designing end-to-end neural network architectures specifically for causal effect estimation. These approaches are particularly compelling when dealing with very high-dimensional data (e.g., images, text) or when intricate, non-linear relationships are suspected among confounders, treatment, and outcome.
The motivation stems from the powerful representation learning capabilities of deep neural networks. Instead of relying on pre-specified feature interactions or basis expansions, deep learning models can potentially learn complex data representations ϕ(X) from high-dimensional covariates X that are optimized for the causal task at hand.
Several architectures adapt standard neural networks for estimating potential outcomes and treatment effects. Two foundational examples are TARNet and DragonNet.
TARNet employs a simple but effective architecture. It uses a shared set of bottom layers to learn a representation ϕ(X) directly from the covariates X. This shared representation then feeds into two separate "heads": shallow networks that estimate the expected potential outcomes under treatment (E[Y(1)∣X=x]) and control (E[Y(0)∣X=x]), respectively.
Structure of TARNet. Input covariates X pass through shared layers learning a representation ϕ(X). This representation is used by separate heads to predict potential outcomes under control (h0) and treatment (h1).
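As a concrete sketch, the TARNet structure above can be expressed as a shared representation network feeding two outcome heads. The following NumPy example is illustrative only: the layer sizes, single-hidden-layer MLPs, and function names are assumptions, not part of the original architecture specification.

```python
# Minimal NumPy sketch of TARNet: shared layers phi(X) feeding two outcome
# heads h0 and h1. Layer sizes and initialization are illustrative choices.
import numpy as np

rng = np.random.default_rng(0)

def init_mlp(d_in, d_hidden, d_out):
    """Parameters of a one-hidden-layer MLP (illustrative sizes)."""
    return {
        "W1": rng.normal(0, 0.1, (d_in, d_hidden)), "b1": np.zeros(d_hidden),
        "W2": rng.normal(0, 0.1, (d_hidden, d_out)), "b2": np.zeros(d_out),
    }

def mlp(params, x):
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])  # ReLU hidden layer
    return h @ params["W2"] + params["b2"]

d_x, d_rep = 5, 8
shared = init_mlp(d_x, 16, d_rep)   # shared layers: representation phi(X)
head0 = init_mlp(d_rep, 8, 1)       # h0: expected outcome under control
head1 = init_mlp(d_rep, 8, 1)       # h1: expected outcome under treatment

def tarnet_forward(X):
    phi = mlp(shared, X)
    return mlp(head0, phi).ravel(), mlp(head1, phi).ravel()

def factual_loss(X, T, Y):
    """Squared error on the observed (factual) outcome only."""
    y0_hat, y1_hat = tarnet_forward(X)
    y_factual = T * y1_hat + (1 - T) * y0_hat
    return np.mean((Y - y_factual) ** 2)

def cate(X):
    """Estimated CATE(x) = h1(phi(x)) - h0(phi(x))."""
    y0_hat, y1_hat = tarnet_forward(X)
    return y1_hat - y0_hat

# Tiny synthetic batch: only the factual outcome contributes to the loss.
X = rng.normal(size=(10, d_x))
T = rng.integers(0, 2, size=10)
Y = rng.normal(size=10)
loss = factual_loss(X, T, Y)
```

In a real implementation the parameters would be trained by gradient descent on this factual loss; the sketch only shows how the shared representation and the two heads combine.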
The model is trained end-to-end by minimizing the empirical risk based on the observed factual outcomes:
$$\mathcal{L}_{\text{factual}} = \frac{1}{n}\sum_{i=1}^{n} \ell\Big(Y_i,\; T_i\,h_1(\phi(X_i)) + (1-T_i)\,h_0(\phi(X_i))\Big)$$

where ℓ is a suitable loss function (e.g., mean squared error for continuous outcomes). The key idea is that requiring a single shared representation ϕ(X) to predict the outcome Y for both treatment groups encourages the network to learn features relevant for outcome prediction, implicitly helping to balance the covariate distributions of the treated and control groups within the representation space. The CATE can then be estimated as:
$$\widehat{\text{CATE}}(x) = h_1(\phi(x)) - h_0(\phi(x))$$

DragonNet extends TARNet by explicitly incorporating a propensity score estimation head, recognizing that the representation ϕ(X) must be good not only for predicting outcomes but also for adjusting for confounding bias. The architecture includes three heads stemming from the shared representation layers: the two outcome heads h0 and h1, plus a head that predicts the propensity score e(ϕ(X)).
Structure of DragonNet. Similar to TARNet but adds a third head to predict the propensity score e(ϕ(X)), alongside the outcome heads h0 and h1.
The training objective for DragonNet typically combines the factual outcome loss (as in TARNet) with a loss for the propensity score prediction (e.g., binary cross-entropy):
$$\mathcal{L}_{\text{DragonNet}} = \mathcal{L}_{\text{factual}} + \alpha\,\mathcal{L}_{\text{propensity}}, \qquad \mathcal{L}_{\text{propensity}} = \frac{1}{n}\sum_{i=1}^{n} \ell_{\text{BCE}}\big(T_i,\, e(\phi(X_i))\big)$$

where α is a hyperparameter balancing the two tasks. The intuition is that forcing the representation ϕ(X) to also be predictive of the treatment assignment T acts as a "targeted regularization", ensuring that the representation captures confounding information effectively. This can lead to better-calibrated CATE estimates, again computed as CATE(x) = h1(ϕ(x)) − h0(ϕ(x)). Some variants also add loss terms that explicitly encourage the distribution of ϕ(X) to be similar for treated and control groups, using techniques such as Maximum Mean Discrepancy (MMD).
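The combined DragonNet objective can be sketched directly from the formula above. In this NumPy example, the head outputs are toy values standing in for the network's predictions; the function names and the clipping constant are illustrative assumptions.

```python
# Minimal NumPy sketch of the DragonNet objective:
# L = L_factual + alpha * L_propensity, computed from the three heads' outputs.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(t, p, eps=1e-7):
    """Binary cross-entropy between treatment labels t and propensities p."""
    p = np.clip(p, eps, 1 - eps)  # avoid log(0)
    return -np.mean(t * np.log(p) + (1 - t) * np.log(1 - p))

def dragonnet_loss(Y, T, y0_hat, y1_hat, e_logit, alpha=1.0):
    """Factual MSE plus alpha-weighted propensity BCE (sketch)."""
    y_factual = T * y1_hat + (1 - T) * y0_hat
    l_factual = np.mean((Y - y_factual) ** 2)
    l_propensity = bce(T, sigmoid(e_logit))
    return l_factual + alpha * l_propensity

# Toy head outputs for a batch of 4 units (illustrative values only).
Y = np.array([1.0, 0.5, 2.0, 1.5])
T = np.array([1, 0, 1, 0])
y0_hat = np.array([0.8, 0.6, 1.5, 1.4])   # h0(phi(X))
y1_hat = np.array([1.1, 0.9, 1.9, 1.7])   # h1(phi(X))
e_logit = np.array([0.4, -0.3, 0.8, -0.5])  # logits of propensity head
total = dragonnet_loss(Y, T, y0_hat, y1_hat, e_logit, alpha=1.0)
```

Setting α = 0 recovers the plain TARNet factual loss, which makes the role of the hyperparameter easy to check empirically.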
Beyond TARNet and DragonNet, the flexibility of neural networks has led to other architectures that address more complex scenarios or make different assumptions, reflecting the ongoing research in applying deep learning to causality.
Successfully applying these deep learning models requires careful attention to several practical aspects, such as the choice of architecture, the weighting of auxiliary losses (e.g., the α hyperparameter in DragonNet), and how well the learned representation ϕ(X) balances the treated and control groups.
Deep learning approaches offer distinct benefits and drawbacks.

Advantages: they scale to very high-dimensional data (e.g., images, text) and can learn intricate, non-linear relationships among confounders, treatment, and outcome without pre-specified feature interactions or basis expansions.

Disadvantages: they demand significant expertise in both deep learning and causal inference, and require careful implementation, tuning, and validation.
Deep learning methods represent a powerful set of tools for causal effect estimation in high dimensions, especially when dealing with complex data structures and non-linearities where methods like DML or Causal Forests might be less suitable. However, their successful application demands significant expertise in both deep learning and causal inference principles, along with careful implementation and validation.
© 2025 ApX Machine Learning