Time Series Foundation Models for Counterfactual Prediction with Panel Data: Numerical Properties and Comparisons
We adapt time series foundation models to the task of counterfactual prediction in panel data and cast them in a common framework with traditional linear panel data estimators (Difference-in-Differences, Synthetic Control, Synthetic Difference-in-Differences). Early time series foundation models served as attention-based extensions of autoregressive univariate forecasting methods. They therefore fall under the ‘horizontal regression’ class of imputation methods in panel data and consequently perform poorly in settings with selection on unobservables and informative controls. More recent time series foundation models accommodate covariate information. Supplying the control units’ time series as covariates turns imputation into a ‘hybrid regression’, which improves counterfactual prediction in most of our designs. Using carefully calibrated simulation studies based on real panel data, we characterize conditions under which modern foundation models outperform traditional methods. In small-data settings with low-rank linear latent factor structures, synthetic control-type estimators (SC and SDID) achieve superior bias-variance tradeoffs. However, in settings with nonlinear dynamics, spectral complexity, or high-dimensional manifold structures, foundation models leveraging attention mechanisms provide substantial improvements. Our results suggest that the choice between traditional estimators and foundation models should depend on the complexity of the underlying data generating process. We also provide a Python implementation of the foundation model-based panel data estimator, ChronosATTEstimator.
panel data, causal inference, difference-in-differences, synthetic control, foundation models, time series forecasting
Introduction
Panel data methods for causal inference have become ubiquitous in applied economics and social sciences. The canonical approaches—Difference-in-Differences (DID), Synthetic Control (SC), and the recently proposed Synthetic Difference-in-Differences (SDID)—rely on linear assumptions about the data generating process. Specifically, these methods assume that untreated potential outcomes can be represented as low-rank linear combinations of unit and time effects (Arkhangelsky et al. 2021; Abadie, Diamond, and Hainmueller 2010, 2015). The central task in these methods is to construct counterfactual outcomes for treated units by leveraging information from control units and/or pre-treatment periods, which is effectively a forecasting/imputation problem.
Recent advances in time series forecasting have produced foundation models, large-scale neural networks pre-trained on diverse time series data, that can generate accurate forecasts without domain-specific tuning (Ansari et al. 2024; Das et al. 2023). These models, particularly those based on transformer architectures, can capture complex nonlinear dynamics, long-range dependencies, and high-dimensional interactions that are difficult to represent in traditional linear frameworks. Because they are trained on large corpora of time series exhibiting a wide variety of patterns, they are potentially well suited to panel data causal inference tasks. This raises a natural question: can these foundation models be adapted for panel data causal inference, and under what conditions do they outperform traditional linear estimators? We address this question through carefully designed simulation studies that isolate key features of the data generating process.
Contributions
Our main contributions are:
Methodological framework: We develop a hybrid regression approach that adapts time series foundation models (specifically Chronos-2) for panel data causal inference, combining information from treated units’ histories with control units’ trajectories.
Numerical characterization: We provide detailed numerical comparisons of bias, variance, and root mean squared error across traditional estimators (DID, SC, SDID) and foundation model approaches.
Nonlinear designs: We construct nonlinear data generating processes specifically designed to break linear assumptions while maintaining realistic treatment assignment mechanisms.
Practical guidance: We identify conditions under which practitioners should prefer foundation models over traditional methods, and vice versa.
Relationship to Existing Literature
Our work builds on several strands of literature:
Panel data methods: The synthetic control method (Abadie, Diamond, and Hainmueller 2010, 2015) constructs counterfactuals as weighted combinations of control units. DID relies on parallel trends assumptions (Bertrand, Duflo, and Mullainathan 2004; Angrist and Pischke 2008). Hybrid estimators such as SDID (Arkhangelsky et al. 2021) and Augmented Synthetic Control (Ben-Michael, Feller, and Rothstein 2018) combine both approaches, achieving robust performance across settings where either DID or SC would traditionally be used. Athey et al. (2021) unify these perspectives, classifying methods into “Horizontal” (exploiting time series patterns) and “Vertical” (exploiting cross-sectional patterns) regressions, and Shen et al. (2023) show that while they may yield identical point estimates under symmetric regularization, they imply different sources of randomness for inference.
Factor models: Interactive fixed effects models (Bai 2009; Moon and Weidner 2015; Xu 2017) generalize DID by allowing for latent factors that interact across units and time. Matrix completion methods (Athey et al. 2021) estimate treatment effects while explicitly fitting low-rank structures.
Machine learning for causal inference: Recent work has explored neural networks (Farrell, Liang, and Misra 2021; Chernozhukov et al. 2022), random forests (Wager and Athey 2018), and other flexible methods for heterogeneous treatment effect estimation. Our focus differs by leveraging pre-trained foundation models rather than training from scratch.
Time series foundation models: Chronos (Ansari et al. 2024, 2025), TimesFM (Das et al. 2023), and related models demonstrate strong zero-shot forecasting performance by pre-training on diverse time series corpora. We adapt these models to the panel data setting by incorporating control units as covariates.
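The equivalence noted by Shen et al. (2023) can be checked numerically in the simplest case of one treated unit and one post-treatment period. The sketch below (function names are illustrative; no intercepts or regularization) fits both regressions by minimum-norm least squares via NumPy's pseudoinverse. The two predictions coincide because the pseudoinverse of a transpose is the transpose of the pseudoinverse.

```python
import numpy as np


def horizontal_prediction(Y_co_pre, y_co_post, y_tr_pre):
    """Regress the controls' post-period outcome on their pre-period outcomes
    (one row per control unit), then apply the coefficients to the treated history."""
    beta = np.linalg.pinv(Y_co_pre) @ y_co_post   # T_pre coefficients, min-norm OLS
    return float(y_tr_pre @ beta)


def vertical_prediction(Y_co_pre, y_co_post, y_tr_pre):
    """Regress the treated unit's pre-period outcomes on the controls' pre-period
    outcomes (one row per period), then apply the unit weights to the controls' post outcome."""
    omega = np.linalg.pinv(Y_co_pre.T) @ y_tr_pre  # N_co weights, min-norm OLS
    return float(omega @ y_co_post)


rng = np.random.default_rng(0)
N_co, T_pre = 8, 12   # fewer controls than periods: the horizontal regression is underdetermined
Y_co_pre = rng.normal(size=(N_co, T_pre))    # controls x pre-treatment periods
y_co_post = rng.normal(size=N_co)            # controls' outcome in the post period
y_tr_pre = rng.normal(size=T_pre)            # treated unit's pre-treatment history

h = horizontal_prediction(Y_co_pre, y_co_post, y_tr_pre)
v = vertical_prediction(Y_co_pre, y_co_post, y_tr_pre)
print(h, v)  # equal up to floating-point error
```

With explicit ridge penalties the two predictions generally differ unless the penalties are chosen symmetrically, which is the qualification in the sentence above.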
Methods
Panel Data Setup
We consider panel data with \(N\) units observed over \(T\) time periods. Let \(Y_{it}\) denote the outcome for unit \(i\) at time \(t\). Treatment is assigned in a block design: \(W_{it} = \mathbf{1}\{i \in \text{treated}, t > T_{\text{pre}}\}\), where \(N_{\text{co}}\) units remain untreated (controls), \(N_{\text{tr}} = N - N_{\text{co}}\) units are treated, and treatment begins after \(T_{\text{pre}}\) pre-treatment periods, leaving \(T_{\text{post}} = T - T_{\text{pre}}\) post-treatment periods.
Our estimand is the average treatment effect on the treated (ATT):
\[ \tau = \frac{1}{N_{\text{tr}} \cdot T_{\text{post}}} \sum_{i \in \text{treated}} \sum_{t > T_{\text{pre}}} \tau_{it} \]
where \(\tau_{it} = Y_{it}(1) - Y_{it}(0)\) is the individual treatment effect, \(Y_{it}(1)\) is the treated potential outcome, and \(Y_{it}(0)\) is the untreated potential outcome.
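A minimal sketch of the block design and the ATT in NumPy, assuming treated units are stored in the last \(N_{\text{tr}}\) rows and using 0-indexed columns, so column \(T_{\text{pre}}\) is period \(T_{\text{pre}}+1\). The helper names and the constant effect are illustrative; in practice \(Y_{it}(0)\) is unobserved for treated cells, which is why it must be imputed.

```python
import numpy as np


def block_treatment(N, T, N_tr, T_pre):
    """W[i, t] = 1 for the last N_tr units in periods after T_pre (0-indexed columns T_pre..T-1)."""
    W = np.zeros((N, T), dtype=int)
    W[N - N_tr:, T_pre:] = 1
    return W


def att(Y1, Y0, W):
    """Average of tau_it = Y_it(1) - Y_it(0) over the N_tr * T_post treated cells."""
    return float((Y1 - Y0)[W == 1].mean())


N, T, N_tr, T_pre = 50, 40, 10, 30
T_post = T - T_pre
W = block_treatment(N, T, N_tr, T_pre)

rng = np.random.default_rng(1)
Y0 = rng.normal(size=(N, T))   # untreated potential outcomes
Y1 = Y0 + 0.5                  # constant individual effect tau_it = 0.5
print(W.sum() == N_tr * T_post, att(Y1, Y0, W))  # True, approximately 0.5
```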
Traditional Estimators
Difference-in-Differences (DID)
DID estimates \(\tau\) by comparing the average change in outcomes for treated units to the average change for control units:
\[ \hat{\tau}_{\text{DID}} = \left(\bar{Y}_{\text{tr,post}} - \bar{Y}_{\text{tr,pre}}\right) - \left(\bar{Y}_{\text{co,post}} - \bar{Y}_{\text{co,pre}}\right) \]
DID is unbiased when untreated potential outcomes follow an additive two-way fixed effects model: \(Y_{it}(0) = \alpha_i + \beta_t + \varepsilon_{it}\).
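As a sanity check of the statement above, the following sketch (illustrative helper name, simulated data) computes the 2×2 DID estimator for a block design. It recovers \(\tau\) under an additive two-way fixed effects model and picks up a bias when treated units follow a differential linear trend.

```python
import numpy as np


def did(Y, W):
    """2x2 DID for a block design: treated change minus control change."""
    treated = W.any(axis=1)   # units ever treated
    post = W.any(axis=0)      # periods after T_pre
    Y_tr, Y_co = Y[treated], Y[~treated]
    return float((Y_tr[:, post].mean() - Y_tr[:, ~post].mean())
                 - (Y_co[:, post].mean() - Y_co[:, ~post].mean()))


N, T, N_tr, T_pre, tau = 20, 40, 5, 30, 1.0
W = np.zeros((N, T))
W[-N_tr:, T_pre:] = 1

rng = np.random.default_rng(2)
alpha = rng.normal(size=(N, 1))   # unit fixed effects
beta = rng.normal(size=(1, T))    # time fixed effects
Y = alpha + beta + tau * W        # additive TWFE model plus a constant effect
print(did(Y, W))                  # approximately 1.0

# Treated units get an extra linear trend of 0.1 per period:
# bias = 0.1 * (mean post index 34.5 - mean pre index 14.5) = 2.0
trend = 0.1 * np.arange(T) * (np.arange(N) >= N - N_tr)[:, None]
print(did(Y + trend, W))          # approximately 3.0
```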
Synthetic Control (SC)
SC estimates \(\tau\) by constructing a synthetic control unit as a weighted average of control units, where weights \(\omega\) are chosen to match pre-treatment outcomes:
\[ \hat{\omega} = \arg\min_{\omega \geq 0, \sum_j \omega_j = 1} \left\| Y_{i,\text{pre}} - \sum_{j \in \text{control}} \omega_j Y_{j,\text{pre}} \right\|^2 \]
The counterfactual prediction is \(\hat{Y}_{it}(0) = \sum_j \hat{\omega}_j Y_{jt}\) for \(t > T_{\text{pre}}\).
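The SC weights solve a least-squares problem over the simplex. One simple way to compute them, sketched here under the assumption of a single treated unit (or the average of treated units), is projected gradient descent with a sort-based Euclidean projection onto the simplex. This is an illustrative solver, not necessarily the one used for the reported results.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection onto {w : w >= 0, sum(w) = 1} (sort-based algorithm)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    k = np.arange(1, v.size + 1)
    rho = k[u - (css - 1.0) / k > 0][-1]   # largest k with a positive shifted entry
    theta = (css[rho - 1] - 1.0) / rho
    return np.maximum(v - theta, 0.0)


def sc_weights(y_tr_pre, Y_co_pre, n_iter=5000):
    """Projected gradient descent for min_w ||y_tr_pre - Y_co_pre @ w||^2 over the simplex.
    y_tr_pre: (T_pre,); Y_co_pre: (T_pre, N_co), one column per control unit."""
    L = 2.0 * np.linalg.norm(Y_co_pre, 2) ** 2     # Lipschitz constant of the gradient
    w = np.full(Y_co_pre.shape[1], 1.0 / Y_co_pre.shape[1])
    for _ in range(n_iter):
        grad = 2.0 * Y_co_pre.T @ (Y_co_pre @ w - y_tr_pre)
        w = project_to_simplex(w - grad / L)
    return w


rng = np.random.default_rng(0)
Y_co_pre = rng.normal(size=(20, 3))
y_tr_pre = Y_co_pre @ np.array([0.3, 0.7, 0.0])  # treated = convex combination of controls
w = sc_weights(y_tr_pre, Y_co_pre)
print(np.round(w, 4))  # approximately [0.3, 0.7, 0.0]
```

The counterfactual for post-treatment periods is then Y_co_post @ w, where Y_co_post holds the controls' post-treatment outcomes column by column.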
Synthetic Difference-in-Differences (SDID)
SDID combines unit weights \(\omega\) and time weights \(\lambda\), each obtained from its own optimization problem (Appendix C.1) so that the weighted comparison approximately satisfies parallel trends:
\[ \hat{Y}_{it}(0) = \sum_{s \in \text{pre}} \hat{\lambda}_s Y_{is} + \sum_{j \in \text{control}} \hat{\omega}_j \left[Y_{jt} - \sum_{s \in \text{pre}} \hat{\lambda}_s Y_{js}\right] \]
SDID has been shown to achieve good bias-variance tradeoffs across settings where either DID or SC would be appropriate (Arkhangelsky et al. 2021).
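Given weights, the SDID counterfactual above takes a few lines of NumPy. This sketch (hypothetical helper sdid_att, applied to the average of treated units) takes \(\omega\) and \(\lambda\) as inputs rather than estimating them. A useful check: with uniform weights over controls and pre-treatment periods, the SDID estimate reduces to the DID estimate.

```python
import numpy as np


def sdid_att(Y, W, omega, lam):
    """ATT implied by the SDID counterfactual, given unit weights omega (over controls)
    and time weights lam (over pre-treatment periods), for the average treated unit."""
    treated, post = W.any(axis=1), W.any(axis=0)
    y_tr = Y[treated].mean(axis=0)    # average treated trajectory, shape (T,)
    Y_co = Y[~treated]                # control outcomes, shape (N_co, T)
    # lam-weighted treated pre level + omega-weighted control deviations from their
    # own lam-weighted pre levels, one value per post-treatment period
    y_hat = lam @ y_tr[~post] + omega @ (Y_co[:, post] - (Y_co[:, ~post] @ lam)[:, None])
    return float((y_tr[post] - y_hat).mean())


rng = np.random.default_rng(3)
N, T, N_tr, T_pre = 12, 15, 3, 10
Y = rng.normal(size=(N, T))
W = np.zeros((N, T))
W[-N_tr:, T_pre:] = 1

omega = np.full(N - N_tr, 1.0 / (N - N_tr))   # uniform unit weights
lam = np.full(T_pre, 1.0 / T_pre)             # uniform time weights
did_estimate = ((Y[-N_tr:, T_pre:].mean() - Y[-N_tr:, :T_pre].mean())
                - (Y[:-N_tr, T_pre:].mean() - Y[:-N_tr, :T_pre].mean()))
print(np.isclose(sdid_att(Y, W, omega, lam), did_estimate))  # True
```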
Foundation Model Approaches
Time series foundation models offer a fundamentally different approach to panel data causal inference by leveraging pre-trained representations learned from massive corpora of diverse time series. We use Chronos-2 (Ansari et al. 2025), a 120M-parameter encoder-only transformer model trained on nearly 100 billion time series observations, which supports zero-shot forecasting for univariate, multivariate, and covariate-informed tasks.
Why Foundation Models for Panel Data?
Training at scale: Chronos models are pre-trained on large and diverse collections of time series from domains including economics, weather, energy, and web traffic (Ansari et al. 2025). This broad exposure enables strong zero-shot transfer to new domains without task-specific fine-tuning, including panel data settings where observations may come from novel contexts.
Tokenization and learning objective: Following the language modeling paradigm, the original Chronos models (Ansari et al. 2024) convert continuous time series into discrete tokens through scaling and quantization. Each time series is first scaled by its mean absolute value, then quantized into 4,096 uniformly spaced bins spanning \([-15, +15]\). The model is trained to minimize the categorical cross-entropy between predicted token distributions and the observed tokens, which lets it learn sequential patterns in time series. Chronos-2 (Ansari et al. 2025) departs from this token-based design but still normalizes each series before modeling. Such per-series scaling provides robustness to scale differences across series, which is critical for panel data where units may have vastly different outcome magnitudes.
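To illustrate the scheme described for the original Chronos models, the sketch below applies mean scaling and uniform quantization with NumPy. It assumes bin centers evenly spaced over \([-15, 15]\) and assigns each value to the nearest center; the library's exact bin edges and special tokens are not reproduced. Scaling a series by a constant leaves its tokens unchanged, which is the scale robustness referred to above.

```python
import numpy as np

N_BINS, LOW, HIGH = 4096, -15.0, 15.0
centers = np.linspace(LOW, HIGH, N_BINS)        # assumed: uniformly spaced bin centers
edges = (centers[:-1] + centers[1:]) / 2.0      # boundaries between neighboring bins


def tokenize(x):
    """Mean scaling followed by uniform quantization (illustrative)."""
    scale = np.mean(np.abs(x))
    scale = scale if scale > 0 else 1.0          # guard against an all-zero context
    tokens = np.digitize(x / scale, edges)       # index of the nearest center, 0..N_BINS-1
    return tokens, scale


def detokenize(tokens, scale):
    """Map tokens back to values on the original scale."""
    return centers[tokens] * scale


x = np.array([1.0, 2.0, 3.0, 2.5, 4.0])
tok_small, s_small = tokenize(x)
tok_big, s_big = tokenize(1000.0 * x)            # the same series in different units
print(np.array_equal(tok_small, tok_big))        # True: tokens do not depend on the units
print(np.max(np.abs(detokenize(tok_small, s_small) - x)))  # at most half a bin width times the scale
```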
Unifying Horizontal and Vertical Information: Chronos-2’s architecture inherently unifies the “Horizontal” and “Vertical” perspectives described by Shen et al. (2023) through its attention mechanisms:
- Time Attention (Horizontal): Within each series, the model attends to past tokens to predict future ones, exploiting serial correlation and autoregressive dynamics.
- Group Attention (Vertical): The model aggregates information across multiple related series (control units) in a batch, exploiting cross-sectional correlations.
This dual-attention mechanism lets the model implicitly perform a generalized, nonlinear analogue of the combined horizontal and vertical estimation discussed by Shen et al. (2023), learning simultaneously from unit-specific history and cross-sectional peers without an explicit choice between a horizontal and a vertical regression specification.
Group attention for cross-learning: The key architectural innovation of Chronos-2 (Ansari et al. 2025) is to alternate between these two layers: time attention aggregates information across time steps within a series, and group attention aggregates information across multiple related series in a batch. Learning from groups of time series simultaneously is a natural fit for panel data, where control units provide information about common time-varying factors affecting all units.
Horizontal Regression (Univariate Forecasting)
The simplest foundation model approach treats each treated unit’s counterfactual prediction as a pure univariate forecasting problem:
\[ \hat{Y}_{i}(t) = f_{\theta}(Y_i(1), \ldots, Y_i(T_{\text{pre}})) \quad \text{for } i \in \text{treated}, t > T_{\text{pre}} \]
where \(f_{\theta}\) is the pre-trained Chronos-2 model. This approach relies solely on autoregressive patterns in the treated unit’s own history, ignoring control units entirely.
Limitation: Fails to exploit valuable information from control units experiencing the same time-varying confounders. As we show in our results, this leads to substantial bias when the pre-treatment period is short or when treatment assignment correlates with unit-specific trends.
Hybrid Regression (Covariate-Informed Forecasting)
Our proposed approach leverages Chronos-2’s group attention mechanism to incorporate control units as covariates:
\[\begin{aligned} \hat{Y}_{i}(t) = f_{\theta}\big( & \text{target} = Y_i(1:T_{\text{pre}}), \\ & \text{past\_cov} = \{Y_j(1:T_{\text{pre}}) : j \in \text{control}\}, \\ & \text{future\_cov} = \{Y_j(T_{\text{pre}}+1:T) : j \in \text{control}\} \big) \end{aligned}\]
Here, control units enter the model in two ways:
Past covariates: Pre-treatment observations \(\{Y_j(1:T_{\text{pre}})\}\) allow the model to learn correlations between the treated unit and controls, analogous to SDID’s unit weights.
Future covariates: Post-treatment observations \(\{Y_j(T_{\text{pre}}+1:T)\}\) provide real-time information about common shocks and trends affecting all units.
Group attention mechanism: Within each transformer block, time attention first processes each series independently (learning temporal dynamics), then group attention aggregates information across the target series and all covariates at each time step. This cross-learning enables the model to implicitly construct synthetic controls by identifying which units and time periods are most predictive—a data-driven analogue of SDID’s optimization-based weights.
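The order of operations within a block can be sketched with plain scaled dot-product attention. This is a simplification under stated assumptions (single head, identity query/key/value projections, no masking, residual connections, or normalization), not the Chronos-2 implementation; it only shows which axis each layer attends over for a stack of a target series and its control covariates.

```python
import numpy as np


def attention(Q, K, V):
    """Scaled dot-product attention over the second-to-last axis."""
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V


def time_attention(H):
    """H: (n_series, T, d). Each series attends over its own time steps."""
    return attention(H, H, H)


def group_attention(H):
    """H: (n_series, T, d). At each time step, the series attend to one another."""
    Ht = np.swapaxes(H, 0, 1)                    # (T, n_series, d): series axis second-to-last
    return np.swapaxes(attention(Ht, Ht, Ht), 0, 1)


rng = np.random.default_rng(0)
H = rng.normal(size=(1 + 9, 30, 8))       # target + 9 controls, 30 time steps, width 8
out = group_attention(time_attention(H))  # one block: time attention, then group attention
print(out.shape)                          # (10, 30, 8)
```

With a single series, group attention has nothing to attend to and returns its input, which is why horizontal regression cannot benefit from this layer.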
Comparison to SDID: While SDID explicitly solves for unit and time weights \((\omega, \lambda)\) to achieve parallel trends, Chronos-2 learns implicit weights through attention scores computed from learned representations. The key advantage is flexibility: attention weights can be nonlinear functions of the data and can vary across time steps, potentially adapting to regime changes or state-dependent dynamics that linear weights cannot capture.
Domain Adaptation via Fine-Tuning
While foundation models exhibit strong zero-shot performance, the distribution of dynamics in a specific panel dataset (e.g., state-level employment trends) may differ from the general pre-training corpus. To bridge this gap, we implement a lightweight fine-tuning procedure.
Mechanism: Before inference, we optionally update the model parameters \(\theta\) by minimizing the forecasting error (negative log-likelihood) on the pre-treatment histories of all available units.
\[ \theta^* = \arg\min_{\theta} \sum_{i=1}^N \sum_{t=1}^{T_{\text{pre}}} -\log P_{\theta}(Y_{it} | Y_{i, 1:t-1}) \]
Rationale: This serves as a form of domain adaptation. By exposing the model to the specific autocorrelation structures, seasonality, and noise levels of the target dataset, we allow it to specialize its internal representations. We expect this to be most useful when the panel is moderately large (sufficient \(N \times T_{\text{pre}}\)), so that the model can learn local patterns without overfitting and the epistemic uncertainty associated with distribution shift is reduced.
Implementation Details
We use the pre-trained amazon/chronos-2 model without fine-tuning, so all reported results are zero-shot. The implementation is available in our Python package as the ChronosATTEstimator class.
For each treated unit, we construct a forecasting problem. The two approaches correspond to specific parameter settings of the estimate_att method:
- Horizontal Regression: Corresponds to use_controls=False. The model inputs only the target series:
  - Target series: Treated unit’s pre-treatment history \((Y_i(1), \ldots, Y_i(T_{\text{pre}}))\)
- Hybrid Regression: Corresponds to use_controls=True. The model additionally inputs control units as covariates:
  - Past covariates: All control units’ pre-treatment histories
  - Future covariates: All control units’ post-treatment observations (together with the past covariates, their full histories)
Fine-Tuning: Both approaches can be combined with the estimator’s warmup method, which performs gradient-based updates on the model weights using pre-treatment data from all units before counterfactuals are generated. Its num_steps argument sets the number of training steps (e.g., 100; the default is 500).
The model returns predictive distributions for the post-treatment period, summarized as quantiles together with a point forecast. We use the point forecast (reported by the pipeline as the mean) as the counterfactual outcome \(\hat{Y}_i(t)\) for \(t > T_{\text{pre}}\).
Data Generating Processes
We consider two classes of DGPs:
Linear Factor Model (Benchmark)
Following Arkhangelsky et al. (2021), we use a latent factor model:
\[ Y_{it} = L_{it} + \tau W_{it} + E_{it}, \quad L = \Gamma \Upsilon^{\top} \]
where \(\Gamma\) is an \(N \times R\) matrix of unit factors, \(\Upsilon\) is a \(T \times R\) matrix of time factors, and \(E\) has AR(2) serial correlation within units. We decompose \(L = F + M\) into additive fixed effects \(F_{it} = \alpha_i + \beta_t\) and interactive effects \(M_{it} = L_{it} - F_{it}\).
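A minimal simulation of this benchmark, assuming Gaussian factors and illustrative AR(2) coefficients and noise scale (not the calibrated values), with \(\tau W_{it}\) omitted as in the placebo design. For the decomposition we take the two-way projection \(F_{it} = \bar{L}_{i\cdot} + \bar{L}_{\cdot t} - \bar{L}_{\cdot\cdot}\), one natural way to define the additive fixed effects; the remainder \(M\) then has zero row and column means.

```python
import numpy as np


def simulate_factor_panel(N=50, T=40, R=4, phi=(0.5, 0.2), sigma=0.1, seed=0):
    """L = Gamma @ Upsilon.T plus AR(2) errors within each unit (illustrative parameters)."""
    rng = np.random.default_rng(seed)
    Gamma = rng.normal(size=(N, R))      # unit factors, N x R
    Upsilon = rng.normal(size=(T, R))    # time factors, T x R
    L = Gamma @ Upsilon.T
    shocks = sigma * rng.normal(size=(N, T))
    E = np.zeros((N, T))
    for t in range(T):
        E[:, t] = shocks[:, t]
        if t >= 1:
            E[:, t] += phi[0] * E[:, t - 1]
        if t >= 2:
            E[:, t] += phi[1] * E[:, t - 2]
    return L, E


def twfe_decompose(L):
    """Split L into an additive part F (alpha_i + beta_t) and an interactive remainder M."""
    row = L.mean(axis=1, keepdims=True)
    col = L.mean(axis=0, keepdims=True)
    F = row + col - L.mean()
    return F, L - F


L, E = simulate_factor_panel()
Y0 = L + E                                # untreated outcomes (tau = 0 in the placebo design)
F, M = twfe_decompose(L)
print(np.allclose(M.mean(axis=0), 0), np.allclose(M.mean(axis=1), 0))  # True True
```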
Calibration: We fit this model to Current Population Survey (CPS) data on log-wages (1979-2019, 50 states) and Penn World Table GDP data, using the fitted components to generate synthetic data.
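One standard way to obtain the fitted components is a truncated singular value decomposition, which gives the best rank-\(R\) approximation in Frobenius norm. The sketch below assumes this approach and a known rank; how the rank is chosen and how the AR(2) parameters are estimated from the residuals are left out.

```python
import numpy as np


def fit_low_rank(Y, R):
    """Rank-R approximation of an N x T panel via truncated SVD.
    Returns unit factors Gamma (N x R), time factors Upsilon (T x R), and L = Gamma @ Upsilon.T."""
    U, s, Vt = np.linalg.svd(Y, full_matrices=False)
    Gamma = U[:, :R] * s[:R]    # absorb the singular values into the unit factors
    Upsilon = Vt[:R].T
    return Gamma, Upsilon, Gamma @ Upsilon.T


rng = np.random.default_rng(0)
L_true = rng.normal(size=(50, 4)) @ rng.normal(size=(4, 40))  # exactly rank 4
Y = L_true + 0.01 * rng.normal(size=(50, 40))
Gamma, Upsilon, L_hat = fit_low_rank(Y, R=4)
resid = Y - L_hat   # residuals from which an error process (e.g. AR(2)) could be fitted
print(Gamma.shape, Upsilon.shape)  # (50, 4) (40, 4)
```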
Nonlinear DGPs
To test when foundation models outperform linear methods, we construct three nonlinear DGPs:
ReLU-Factor Model (Threshold Nonlinearity)
Units respond to a common shock only when it exceeds unit-specific thresholds:
\[ Y_{it} = \beta_i \cdot \text{ReLU}(F_t - \theta_i) + \eta_{it} \]
where \(F_t\) is a random walk and treatment is correlated with threshold \(\theta_i\).
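A sketch of this DGP under illustrative assumptions: Gaussian thresholds, loadings \(\beta_i\) drawn uniformly from \([0.5, 1.5]\), and treatment assigned deterministically to the \(N_{\text{tr}}\) units with the lowest thresholds as a simple stand-in for correlated assignment. The function name and all parameter values are hypothetical.

```python
import numpy as np


def relu_factor_panel(N=50, T=40, N_tr=10, noise=0.1, seed=0):
    """Y_it = beta_i * ReLU(F_t - theta_i) + eta_it with a random-walk common factor F_t."""
    rng = np.random.default_rng(seed)
    F = np.cumsum(rng.normal(size=T))              # random-walk common shock
    theta = np.sort(rng.normal(size=N))[::-1]      # descending: the last N_tr units have the lowest thresholds
    beta = rng.uniform(0.5, 1.5, size=N)           # unit-specific loadings
    signal = beta[:, None] * np.maximum(F[None, :] - theta[:, None], 0.0)
    Y = signal + noise * rng.normal(size=(N, T))
    treated = np.arange(N) >= N - N_tr             # low-threshold units are treated
    return Y, signal, theta, treated


Y, signal, theta, treated = relu_factor_panel()
print(theta[treated].max() < theta[~treated].min())  # True: treated units activate earlier
```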
FM Oscillator Model (Spectral Complexity)
Outcomes are sine waves with unit-specific frequencies:
\[ Y_{it} = \sin(\gamma_i \Psi_t + \phi_i) + \eta_{it} \]
where \(\Psi_t = \sum_{s=1}^t |z_s|\) is a stochastic phase and treatment is correlated with frequency \(\gamma_i\).
Key challenge: A high-frequency signal cannot be synthesized from low-frequency controls via linear combinations.
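A corresponding sketch of the FM Oscillator, again with hypothetical parameter choices: frequencies drawn uniformly from \([0.5, 3]\), uniform phases \(\phi_i\), and the \(N_{\text{tr}}\) highest-frequency units treated, which mimics the rule \(\gamma_i > \gamma_{\text{critical}}\) used in the results section.

```python
import numpy as np


def fm_oscillator_panel(N=50, T=40, N_tr=10, noise=0.1, seed=0):
    """Y_it = sin(gamma_i * Psi_t + phi_i) + eta_it with Psi_t = sum_{s<=t} |z_s|."""
    rng = np.random.default_rng(seed)
    Psi = np.cumsum(np.abs(rng.normal(size=T)))      # nondecreasing stochastic phase
    gamma = np.sort(rng.uniform(0.5, 3.0, size=N))   # ascending: the last N_tr units are fastest
    phase = rng.uniform(0.0, 2.0 * np.pi, size=N)    # unit-specific phase offsets phi_i
    Y = np.sin(gamma[:, None] * Psi[None, :] + phase[:, None]) + noise * rng.normal(size=(N, T))
    treated = np.arange(N) >= N - N_tr               # high-frequency units are treated
    return Y, Psi, gamma, treated


Y, Psi, gamma, treated = fm_oscillator_panel()
print(np.all(np.diff(Psi) >= 0), gamma[treated].min() > gamma[~treated].max())  # True True
```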
Kernel Manifold Model (High-Dimensional Structure)
Outcomes are linear projections of a nonlinearly embedded latent state:
\[ Y_{it} = \langle \mathbf{w}_i, \Phi(Z_t) \rangle + \eta_{it}, \quad \Phi(Z_t) = \cos(W_{\text{random}} Z_t + \mathbf{b}) \]
where \(Z_t \in \mathbb{R}^2\) is a random walk and treatment is correlated with projection vector \(\mathbf{w}_i\).
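The map \(\Phi\) is a random Fourier feature embedding. The sketch below assumes \(D = 64\) features, a Gaussian \(W_{\text{random}}\), uniform offsets \(\mathbf{b}\), and, as one hypothetical way to correlate treatment with \(\mathbf{w}_i\), a common random shift added to the treated units’ projection vectors.

```python
import numpy as np


def kernel_manifold_panel(N=50, T=40, N_tr=10, D=64, noise=0.1, seed=0):
    """Y_it = <w_i, Phi(Z_t)> + eta_it with Phi(Z) = cos(W_random @ Z + b)."""
    rng = np.random.default_rng(seed)
    Z = np.cumsum(0.3 * rng.normal(size=(T, 2)), axis=0)   # 2-D random-walk latent state
    W_random = rng.normal(size=(D, 2))
    b = rng.uniform(0.0, 2.0 * np.pi, size=D)
    Phi = np.cos(Z @ W_random.T + b)                       # (T, D) feature embedding
    w = rng.normal(size=(N, D)) / np.sqrt(D)               # unit projection vectors
    treated = np.arange(N) >= N - N_tr
    w[treated] += rng.normal(size=D) / np.sqrt(D)          # shared shift: treatment depends on w_i
    Y = w @ Phi.T + noise * rng.normal(size=(N, T))
    return Y, Phi, treated


Y, Phi, treated = kernel_manifold_panel()
print(Y.shape, Phi.shape)  # (50, 40) (40, 64)
```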
Visualizing DGPs
To build intuition, we visualize a single realization from each setting, showing the time series for treated (red) and control (gray) units.
Panel A (Linear): Units follow smooth, correlated trends that can be well-approximated by linear combinations. Controls provide good counterfactuals through reweighting.
Panel B (ReLU-Factor): Treated units have low thresholds, so they activate earlier than controls. Linear methods struggle because the threshold creates kinks (switches between an inactive and a responsive regime) that weighted averages of controls cannot reproduce.
Panel C (FM Oscillator): Treated units (red) oscillate at higher frequencies than controls (gray). Linear methods fail because high-frequency signals cannot be synthesized from low-frequency ones. Foundation models succeed by learning the frequency structure from the population.
Panel D (Kernel Manifold): Units evolve on a high-dimensional manifold with complex correlations. Treated units cluster in a specific region of the feature space, creating confounding that linear combinations cannot resolve.
Numerical Results
Linear Settings: CPS Simulation
We first replicate the CPS placebo study from Arkhangelsky et al. (2021) to establish that foundation models do not degrade performance in settings where linear methods excel.
Setup
- Data: Log-wages from CPS (1979-2019), averaged by state-year
- Panel dimensions: \(N=50\) states, \(T=40\) years
- Design: \(N_{\text{tr}}=10\) treated states, \(T_{\text{post}}=10\) post-treatment periods
- Treatment assignment: Logistic model based on state fixed effects and interactive components (calibrated to minimum wage laws)
- DGP: Rank-4 factor model with AR(2) errors fitted to actual data
- True effect: \(\tau = 0\) (placebo study)
Results
| Method | RMSE | Bias | Std |
|---|---|---|---|
| DID | 0.046 | 0.045 | 0.011 |
| SC | 0.019 | 0.005 | 0.019 |
| SDID | 0.022 | 0.008 | 0.020 |
| Chronos_Horizontal | 0.201 | -0.199 | 0.034 |
| Chronos_Hybrid | 0.150 | -0.146 | 0.032 |
Findings:
- SC and SDID perform best: SC attains the lowest RMSE (0.019) and bias (0.005), with SDID close behind (RMSE 0.022); DID is dominated by bias (0.045)
- Chronos Horizontal fails: RMSE of 0.201, high negative bias (-0.199)
- Likely reason: Short pre-treatment window (\(T_{\text{pre}}=30\)) insufficient for pure extrapolation
- Chronos Hybrid improves substantially: RMSE drops to 0.150 (about a 25% reduction vs. horizontal)
- Still lags SC and SDID, suggesting linear methods are preferable when their assumptions hold
- Variance comparison: Chronos methods have higher standard deviations (0.032 to 0.034) than DID, SC, and SDID (0.011 to 0.020), but their RMSE is driven mainly by bias
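For reading the table, note that with the standard deviation computed without a degrees-of-freedom correction, RMSE satisfies \(\text{RMSE}^2 = \text{Bias}^2 + \text{Std}^2\) exactly; for example, \(\sqrt{0.045^2 + 0.011^2} \approx 0.046\) for DID. A minimal sketch with a hypothetical summarize helper, run on synthetic draws rather than the study’s replications:

```python
import numpy as np


def summarize(tau_hats, tau_true=0.0):
    """Bias, standard deviation, and RMSE of an estimator across replications."""
    err = np.asarray(tau_hats) - tau_true
    bias = err.mean()
    std = err.std()                      # ddof=0, so RMSE^2 = bias^2 + std^2 exactly
    rmse = np.sqrt(np.mean(err ** 2))
    return bias, std, rmse


rng = np.random.default_rng(0)
tau_hats = 0.04 + 0.01 * rng.normal(size=100)   # synthetic estimates; placebo truth tau = 0
bias, std, rmse = summarize(tau_hats)
print(np.isclose(rmse ** 2, bias ** 2 + std ** 2))  # True
```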
Interpretation
In linear settings, the explicitly optimized weights of SC and SDID outperform the implicit weighting learned by foundation models. Incorporating control units as covariates nevertheless narrows the gap substantially, which makes it essential for foundation models in this setting.
Nonlinear Settings
FM Oscillator Results
We focus on the FM Oscillator as it provides the clearest separation between methods.
Setup:
- Outcome: \(Y_{it} = \sin(\gamma_i \Psi_t + \phi_i)\), where \(\Psi_t\) is a random walk phase
- Treatment: High-frequency units (\(\gamma_i > \gamma_{\text{critical}}\)) are treated
- Panel: \(N=50\), \(T=40\), \(N_{\text{tr}}=10\), \(T_{\text{post}}=10\)
| Method | RMSE | Bias | Std |
|---|---|---|---|
| DID | 0.973 | -0.948 | 0.221 |
| SC | 0.994 | -0.978 | 0.181 |
| SDID | 0.983 | -0.958 | 0.222 |
| Chronos_Horizontal | 0.090 | -0.010 | 0.089 |
| Chronos_Hybrid | 0.165 | 0.009 | 0.165 |
Findings:
- Chronos Horizontal achieves the best performance: RMSE of 0.090, about 91% lower than SDID
- Chronos Hybrid is second: RMSE of 0.165, about 83% lower than SDID
- Traditional methods struggle, with biases between -0.98 and -0.95:
- SDID: RMSE 0.983 (about 6x worse than Chronos Hybrid)
- SC: RMSE 0.994 (worst performer)
- DID: RMSE 0.973
- Why SDID fails: High frequencies cannot be synthesized from low frequencies via linear combinations
- Why Chronos succeeds: Transformer attention learns frequency extrapolation from population patterns
Additional Nonlinear DGPs
| DGP | DID | SC | SDID | Chronos_Horizontal | Chronos_Hybrid | Improvement |
|---|---|---|---|---|---|---|
| ReLU-Factor | 1.072 | 1.019 | 1.018 | 1.376 | 0.586 | 42% |
| FM Oscillator | 0.973 | 0.994 | 0.983 | 0.090 | 0.165 | 83% |
| Kernel Manifold | 1.002 | 1.024 | 1.024 | 0.197 | 0.143 | 86% |
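Pattern: Across the three nonlinear DGPs, Chronos_Hybrid reduces RMSE by 42% to 86% relative to the best traditional estimator, with the largest gains in the settings with spectral complexity (FM Oscillator) and manifold structure (Kernel Manifold). In the FM Oscillator, Chronos_Horizontal performs better still (RMSE 0.090).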
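The Improvement column can be reproduced from the RMSE entries as the percentage RMSE reduction of Chronos_Hybrid relative to the best of DID, SC, and SDID. This reading is our inference from the numbers; measuring against SDID alone gives the same rounded values.

```python
# RMSE values copied from the table above: DID, SC, SDID, Chronos_Hybrid
rmse = {
    "ReLU-Factor": (1.072, 1.019, 1.018, 0.586),
    "FM Oscillator": (0.973, 0.994, 0.983, 0.165),
    "Kernel Manifold": (1.002, 1.024, 1.024, 0.143),
}

# Percentage RMSE reduction of Chronos_Hybrid vs. the best traditional estimator
improvement = {
    dgp: round(100 * (1 - hybrid / min(did, sc, sdid)))
    for dgp, (did, sc, sdid, hybrid) in rmse.items()
}
print(improvement)  # {'ReLU-Factor': 42, 'FM Oscillator': 83, 'Kernel Manifold': 86}
```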
Visualization: Error Distributions
Discussion
When to Use Foundation Models vs. Traditional Estimators
Our results suggest the following practical guidance:
Use traditional linear methods (especially SDID) when:
- You believe the DGP is well-approximated by low-rank linear factor models
- Panel dimensions are moderate (tens of units/periods)
- Interpretability of weights is important
- Computational cost should be minimal
Use foundation model approaches when:
- Outcomes exhibit complex nonlinear dynamics (thresholds, oscillations, state-dependencies)
- You suspect high-dimensional or manifold structure in the data
- Pre-treatment period is long enough to capture patterns (\(T_{\text{pre}} \geq 20\))
- You have access to GPU computation for inference
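Hybrid regression is usually essential: In the linear setting and in two of the three nonlinear designs (ReLU-Factor and Kernel Manifold), incorporating control units as covariates substantially improves on pure horizontal regression; the FM Oscillator, where horizontal regression performs best, is the exception. This parallels the rationale for hybrid traditional estimators such as SDID, and suggests that foundation models should leverage control information by default.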
Limitations and Future Work
Inference: We have not addressed statistical inference for foundation model estimators; bootstrap and conformal prediction methods warrant investigation. Unlike the traditional estimators analyzed by Shen et al. (2023), zero-shot foundation models do not carry sampling uncertainty from parameter estimation on the panel itself, since the model parameters \(\theta\) are frozen. However, they rely on the assumption that the training distribution covers the target DGP. The foundation model’s output quantiles offer a natural uncertainty estimate that could stand in for the asymptotic variance calculations of model-based inference, though calibrating these quantiles under distribution shift remains an open challenge.
Fine-tuning: We used pre-trained Chronos-2 without domain-specific fine-tuning. Adapting the model to panel data may yield further improvements.
Staggered adoption: Our analysis focused on block treatment designs. Extensions to staggered adoption and time-varying treatment are important.
Real data applications: Validation on real policy evaluations (beyond placebo studies) would strengthen practical recommendations.
Alternative architectures: We focused on Chronos-2, but other foundation models (TimesFM, Lag-Llama) may have different tradeoffs.
Conclusion
We have provided a systematic numerical comparison of traditional linear panel data estimators and time series foundation models for causal inference. Our results demonstrate that the optimal method depends critically on the structure of the data generating process:
- In linear low-rank settings, SC and SDID achieve the lowest RMSE through explicitly optimized weights
- In nonlinear settings with thresholds, spectral complexity, or manifold structure, foundation models leveraging attention mechanisms provide substantial improvements (RMSE reductions of 42% to 86% for hybrid regression relative to the best traditional estimator)
- Incorporating control units as covariates (hybrid regression) is essential in the linear setting, where it reduces RMSE by about 25%, and improves foundation model performance in two of the three nonlinear designs
As time series foundation models continue to improve and become more accessible, they offer a promising complement to traditional panel data methods—particularly in settings with complex dynamics that violate linear assumptions.
References
Online Appendix
A. Implementation Details
A.1 Chronos-2 Model Specification
We use the pre-trained amazon/chronos-2 model with the following settings:
- Prediction length: \(T_{\text{post}}\)
- Inference: Mean of predictive distribution
- Device: CUDA (NVIDIA A100 GPU)
A.2 Computation Time
On a single A100 GPU:
- Chronos inference (100 replications): ~10 minutes
- SDID (100 replications): ~2 minutes (CPU)
A.3 Code Availability
All code for simulations and estimators is available at:
https://github.com/[username]/groupAttentionSynth
B. Additional Simulation Results
B.1 Sensitivity to Panel Dimensions
B.2 Sensitivity to Pre-treatment Length
C. Mathematical Details
C.1 SDID Weight Computation
The SDID weights solve:
\[ \begin{aligned} \hat{\lambda} &= \arg\min_{\lambda \in \Delta_{T_{\text{pre}}}} \left\| Y_{\text{co},\text{post}} - Y_{\text{co},\text{pre}} \lambda \right\|^2 + \zeta_\lambda^2 \|\lambda\|^2 \\ \hat{\omega} &= \arg\min_{\omega \in \Delta_{N_{\text{co}}}} \left\| Y_{\text{tr},\text{pre}}^{\top} - Y_{\text{co},\text{pre}}^{\top} \omega \right\|^2 + \zeta_\omega^2 \|\omega\|^2 \end{aligned} \]
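where \(\Delta_n = \{\theta \in \mathbb{R}^n : \theta \geq 0, \sum_i \theta_i = 1\}\) is the simplex, \(Y_{\text{co},\text{pre}}\) is the \(N_{\text{co}} \times T_{\text{pre}}\) matrix of control outcomes in the pre-treatment periods, \(Y_{\text{co},\text{post}}\) is the \(N_{\text{co}}\)-vector of the control units’ average post-treatment outcomes, and \(Y_{\text{tr},\text{pre}}\) collects the average treated outcome in each pre-treatment period. For simplicity, the display omits the intercept terms included in the weight problems of Arkhangelsky et al. (2021).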
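Both weight problems are quadratic programs over a simplex. A Frank-Wolfe iteration with exact line search is a simple way to solve them; the sketch below implements the display above as written (without intercepts), and the function name and stopping rule are illustrative.

```python
import numpy as np


def simplex_ridge_fw(A, b, zeta=0.0, n_iter=1000, tol=1e-12):
    """Frank-Wolfe for min_x ||b - A @ x||^2 + zeta^2 * ||x||^2 over the unit simplex.
    Unit weights: A = Y_co_pre.T (T_pre x N_co), b = treated pre-period averages.
    Time weights: A = Y_co_pre (N_co x T_pre), b = controls' post-period averages."""
    x = np.full(A.shape[1], 1.0 / A.shape[1])     # start at the simplex barycenter
    for _ in range(n_iter):
        grad = 2.0 * A.T @ (A @ x - b) + 2.0 * zeta ** 2 * x
        s = np.zeros_like(x)
        s[np.argmin(grad)] = 1.0                   # best vertex for the linearized objective
        d = s - x
        slope = grad @ d                           # <= 0; its negative is the duality gap
        if slope > -tol:
            break
        curvature = 2.0 * (np.sum((A @ d) ** 2) + zeta ** 2 * np.sum(d ** 2))
        step = min(1.0, -slope / curvature)        # exact line search for a quadratic
        x = x + step * d
    return x


print(simplex_ridge_fw(np.eye(2), np.array([1.0, 0.0]), zeta=1.0))  # [0.75 0.25]
```

Each update is a convex combination of the current point and a vertex, so the iterates stay in \(\Delta_n\) without any projection step.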
C.2 Foundation Model Attention Weights
While SDID explicitly optimizes for parallel trends, Chronos-2 implicitly learns weights through self-attention:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V \]
where queries, keys, and values are computed from the concatenated sequence of target series and covariates.
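Each row of the softmax matrix is nonnegative and sums to one, so every attention output is a convex combination of value vectors, structurally similar to weights in \(\Delta_n\). The sketch below uses random queries, keys, and values as stand-ins for learned projections; in Chronos-2 such weights vary across time steps, heads, and layers, so they should not be read directly as \(\omega\).

```python
import numpy as np


def attention_weights(Q, K):
    """Row-wise softmax(Q K^T / sqrt(d_k))."""
    scores = Q @ K.T / np.sqrt(K.shape[1])
    scores -= scores.max(axis=1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
d_k, n_controls = 16, 9
q_target = rng.normal(size=(1, d_k))       # stand-in query from the target series at one time step
K = rng.normal(size=(n_controls, d_k))     # stand-in keys from the control series
V = rng.normal(size=(n_controls, 1))       # stand-in values: one scalar per control
a = attention_weights(q_target, K)         # (1, n_controls)
output = a @ V                             # convex combination of the controls' values
print(np.isclose(a.sum(), 1.0), bool((a >= 0).all()))  # True True
```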
C.3 Estimator Implementation
Requires the chronos package (installed with pip as chronos-forecasting); a working CUDA environment is recommended, and the code falls back to CPU otherwise.
"""
Chronos-2 based treatment effect estimator
Uses Chronos-2 to predict counterfactual outcomes for treated units
in post-treatment periods, similar to synthetic control.
"""
import numpy as np
import torch
from chronos import Chronos2Pipeline
from typing import Optional, Tuple
class ChronosATTEstimator:
"""
Treatment effect estimator using Chronos-2 for counterfactual prediction.
"""
def __init__(
self,
model_name: str = "amazon/chronos-2",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
Initialize Chronos-2 pipeline.
Args:
model_name: Chronos model to use
device: device for inference
"""
print(f"Loading Chronos-2 model on {device}...")
self.pipeline = Chronos2Pipeline.from_pretrained(model_name, device_map=device)
self.device = device
self._base_pipeline = None # Store original for reset if needed
def warmup(
self,
Y: np.ndarray,
W: np.ndarray,
prediction_length: Optional[int] = None,
learning_rate: float = 1e-5,
num_steps: int = 500,
batch_size: int = 32,
output_dir: Optional[str] = None,
**fit_kwargs,
) -> None:
"""
Fine-tune the model on pre-treatment data from all units.
This adapts the model to domain-specific patterns and may improve
predictions for longer time series. Results in SDID simulation (T_pre=30)
showed fine-tuning alone doesn't help, but it may be useful for:
- Longer time series (hundreds/thousands of periods)
- Domain-specific dynamics
- Complex patterns worth learning
The fine-tuned model replaces the current pipeline. Call reset() to restore
the original pre-trained model.
Args:
Y: N x T outcome matrix
W: N x T treatment indicator matrix
prediction_length: forecast horizon (if None, infers from W)
learning_rate: optimizer learning rate (default 1e-5)
num_steps: number of training steps
batch_size: batch size for training (smaller than default 256 for small panels)
output_dir: where to save checkpoints (if None, uses temp directory)
**fit_kwargs: additional kwargs passed to pipeline.fit()
(e.g., finetune_mode='lora' for LoRA fine-tuning)
Note:
- Creates a copy of the model in memory during fine-tuning
- Saves checkpoints to disk (can use output_dir to control location)
- Only trains on pre-treatment data (all units)
"""
N, T = Y.shape
# Find pre-treatment period
treated_units = np.where(W.sum(axis=1) > 0)[0]
T_pre = None
for t in range(T):
if len(treated_units) > 0 and W[treated_units[0], t] == 1:
T_pre = t
break
if T_pre is None:
raise ValueError("No treatment found in W matrix")
if prediction_length is None:
prediction_length = T - T_pre
# Collect all units' pre-treatment data
pretreatment_data = []
for i in range(N):
y_pre = Y[i, :T_pre]
pretreatment_data.append(
torch.tensor(y_pre, dtype=torch.float32).reshape(1, -1)
)
print(f"Fine-tuning on {len(pretreatment_data)} time series (T_pre={T_pre})...")
print(
f"Settings: lr={learning_rate}, steps={num_steps}, batch_size={batch_size}"
)
# Store base pipeline if this is first warmup
if self._base_pipeline is None:
self._base_pipeline = self.pipeline
# Fine-tune and replace pipeline
self.pipeline = self.pipeline.fit(
inputs=pretreatment_data,
prediction_length=prediction_length,
learning_rate=learning_rate,
num_steps=num_steps,
batch_size=batch_size,
output_dir=output_dir,
**fit_kwargs,
)
print("Fine-tuning complete. Pipeline updated.")
def reset(self) -> None:
"""
Reset to the original pre-trained model, discarding any fine-tuning.
"""
if self._base_pipeline is not None:
self.pipeline = self._base_pipeline
self._base_pipeline = None
print("Reset to base pre-trained model.")
else:
print("No warmup applied, nothing to reset.")
def estimate_att(
self,
Y: np.ndarray,
W: np.ndarray,
prediction_length: Optional[int] = None,
quantile_levels: Optional[list] = None,
use_controls: bool = True,
) -> Tuple[float, np.ndarray]:
"""
Estimate average treatment effect on the treated (ATT).
Uses Chronos to predict counterfactual outcomes for treated units
in post-treatment periods.
Args:
Y: N x T outcome matrix
W: N x T treatment indicator matrix
prediction_length: number of periods to forecast (if None, infer from W)
quantile_levels: quantile levels for probabilistic forecasts (default uses Chronos defaults)
use_controls: if True, use control units as covariates (hybrid regression)
if False, use only unit's own history (horizontal regression)
Returns:
tau_hat: estimated ATT
errors: array of unit-level errors for diagnostics
"""
N, T = Y.shape
# Identify treated units and treatment timing
treated_units = np.where(W.sum(axis=1) > 0)[0]
# Find first treatment period (assumes block assignment)
T_pre = None
for t in range(T):
if W[treated_units[0], t] == 1:
T_pre = t
break
if T_pre is None:
raise ValueError("No treatment found in W matrix")
T_post = T - T_pre
if prediction_length is None:
prediction_length = T_post
# Get control units
control_units = np.where(W.sum(axis=1) == 0)[0]
# Prepare contexts based on regression type
contexts = []
actuals_list = []
if use_controls:
# HYBRID REGRESSION: Use treated unit's history + control units as covariates
# This is analogous to SDID - combines horizontal and vertical information
for i in treated_units:
y_pre = Y[i, :T_pre]
y_post_actual = Y[i, T_pre : T_pre + prediction_length]
# Create multivariate input with control units as past and future covariates
# Past: control units' pre-treatment period (for learning correlations)
# Future: control units' post-treatment period (for SC-like counterfactual)
past_covariates = {}
future_covariates = {}
for j_idx, j in enumerate(control_units):
key = f"control_{j_idx}"
past_covariates[key] = torch.tensor(
Y[j, :T_pre], dtype=torch.float32
)
future_covariates[key] = torch.tensor(
Y[j, T_pre : T_pre + prediction_length], dtype=torch.float32
)
context = {
"target": torch.tensor(y_pre, dtype=torch.float32),
"past_covariates": past_covariates,
"future_covariates": future_covariates,
}
contexts.append(context)
actuals_list.append(y_post_actual)
# Predict with covariates
with torch.no_grad():
if quantile_levels is not None:
_, mean_forecasts = self.pipeline.predict_quantiles(
contexts,
prediction_length=prediction_length,
quantile_levels=quantile_levels,
predict_batches_jointly=True,
)
else:
_, mean_forecasts = self.pipeline.predict_quantiles(
contexts,
prediction_length=prediction_length,
predict_batches_jointly=True,
)
else:
# HORIZONTAL REGRESSION: Use only unit's own history (baseline)
for i in treated_units:
y_pre = Y[i, :T_pre]
y_post_actual = Y[i, T_pre : T_pre + prediction_length]
# Shape: (1, history_length) for each unit
contexts.append(torch.tensor(y_pre, dtype=torch.float32).reshape(1, -1))
actuals_list.append(y_post_actual)
# Stack all contexts: (n_treated, 1, history_length)
batch_context = torch.cat(contexts, dim=0).unsqueeze(1)
# Generate predictions for all treated units jointly
with torch.no_grad():
if quantile_levels is not None:
_, mean_forecasts = self.pipeline.predict_quantiles(
batch_context,
prediction_length=prediction_length,
quantile_levels=quantile_levels,
predict_batches_jointly=True,
)
else:
_, mean_forecasts = self.pipeline.predict_quantiles(
batch_context,
prediction_length=prediction_length,
predict_batches_jointly=True,
)
# Extract predictions for each unit
# mean_forecasts is a list of tensors, one per series
predictions_list = [
forecast.squeeze().cpu().numpy() for forecast in mean_forecasts
]
# Stack predictions
predictions = np.vstack(predictions_list) # N_tr x T_post
actuals = np.vstack(actuals_list) # N_tr x T_post
# Compute treatment effects
treatment_effects = actuals - predictions # N_tr x T_post
# ATT is the average treatment effect on treated units
tau_hat = treatment_effects.mean()
# Return errors for each unit (averaged over time)
unit_errors = treatment_effects.mean(axis=1)
return tau_hat, unit_errors
def estimate_att_chronos(
Y: np.ndarray,
W: np.ndarray,
model_name: str = "amazon/chronos-2",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> float:
"""
Convenience function to estimate ATT using Chronos-2.
Args:
Y: N x T outcome matrix
W: N x T treatment matrix
model_name: Chronos model to use
device: device for inference
Returns:
tau_hat: estimated ATT
"""
estimator = ChronosATTEstimator(model_name=model_name, device=device)
tau_hat, _ = estimator.estimate_att(Y, W)
return tau_hat