Pure forecasting comparators for treatment-effect counterfactuals
mtgpax is the causal MTGP implementation. dynamax is a useful dependency for non-causal forecasting baselines: state-space models that are fit only to admissible pre-treatment or control information and then forecast the treated post-treatment block as a counterfactual.
The local source checkout is in ref/dynamax/. The dependency is also installed from PyPI through pyproject.toml.
Dynamax API Surface
Dynamax is a JAX library for probabilistic state-space models. Its main families are:
hidden Markov models,
linear Gaussian state-space models,
nonlinear Gaussian state-space models,
generalized Gaussian state-space models with non-Gaussian emissions.
The baseline workhorse should be dynamax.linear_gaussian_ssm.LinearGaussianSSM. Its constructor is:
For emissions_pre with shape (T0, emission_dim), where T0 is the number of pre-treatment periods, y_means has shape (H, emission_dim), where H is the number of forecast steps. These are baseline forecasts of \(Y_{it}(0)\) for the held-out post-treatment periods.
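To make the filter-then-forecast pattern concrete without relying on any particular Dynamax call, here is a minimal pure-Python sketch for the scalar case (state_dim = emission_dim = 1). The function name `kalman_forecast`, the parameters `a`, `q`, `r`, and the diffuse prior are illustrative assumptions and not part of Dynamax; the model is a simple AR(1) latent state observed with Gaussian noise.

```python
def kalman_forecast(y_pre, horizon, a=1.0, q=0.1, r=1.0, m0=0.0, p0=1e6):
    """Filter y_pre through z_t = a*z_{t-1} + w_t, y_t = z_t + v_t, then forecast.

    w_t ~ N(0, q), v_t ~ N(0, r); (m0, p0) is the prior on the first state.
    Returns (y_means, y_vars), each a list of length `horizon`.
    All names here are hypothetical, chosen for illustration only.
    """
    m, p = m0, p0
    for t, y in enumerate(y_pre):
        if t > 0:
            # Predict step: propagate the state one period ahead.
            m, p = a * m, a * a * p + q
        # Update step: condition on the observed pre-treatment outcome.
        k = p / (p + r)
        m, p = m + k * (y - m), (1.0 - k) * p
    y_means, y_vars = [], []
    for _ in range(horizon):
        # No update here: post-treatment outcomes are never conditioned on.
        m, p = a * m, a * a * p + q
        y_means.append(m)
        y_vars.append(p + r)  # emission variance = state variance + noise
    return y_means, y_vars


# A flat pre-treatment history of length T0 = 4, forecast H = 3 steps ahead.
y_means, y_vars = kalman_forecast([5.0, 5.0, 5.0, 5.0], horizon=3)
```

The forecast variance grows with the step count because no post-treatment observation is ever used to tighten the state estimate, which is the behaviour a leakage-free comparator should show.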
Counterfactual Forecasting Setup
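Let \(\mathcal{T}\) be the set of treated units, \(T_i\) the adoption time of unit \(i\), and \(T_{\text{train}}\) the training cutoff applied to units outside \(\mathcal{T}\). A pure forecasting baseline must not condition on treated post-treatment observations, so its admissible conditioning set is: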
\[
\mathcal{C}_{\text{forecast}}
=
\{(i,t): i \in \mathcal{T}, t \le T_i\}
\cup
\{(i,t): i \notin \mathcal{T}, t \le T_{\text{train}}\}.
\]
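The conditioning set above can be sketched as a boolean mask over (unit, period) cells. This is a minimal illustration: the function name `admissible_mask`, the `adoption` dictionary, and the 1-indexed period convention are assumptions made for the example. It follows the inequality \(t \le T_i\) exactly as written, so \(T_i\) is read as the last admissible period for unit \(i\).

```python
def admissible_mask(n_units, n_periods, adoption, t_train):
    """Return mask[i][t-1] == True where cell (i, t) may be conditioned on.

    adoption maps treated unit index -> T_i; units absent from it are controls.
    Periods are 1-indexed to match the formula, so column t-1 stores period t.
    (Hypothetical helper, not part of mtgpax or dynamax.)
    """
    mask = []
    for i in range(n_units):
        # Treated units are cut at their own T_i, controls at the shared T_train.
        cutoff = adoption[i] if i in adoption else t_train
        mask.append([t <= cutoff for t in range(1, n_periods + 1)])
    return mask


# Unit 0 is treated with T_0 = 3; units 1 and 2 are controls cut at T_train = 5.
mask = admissible_mask(n_units=3, n_periods=6, adoption={0: 3}, t_train=5)
```

Applying such a mask before any fitting call is one way to make leakage checks explicit: every cell passed to the forecaster can be asserted against it.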
For a treated unit \(i\), the forecast horizon is:
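Consistent with the conditioning set above, the horizon can be written as follows, assuming \(T\) denotes the final observed period and \(H_i\) the number of forecast steps (both symbols are introduced here for illustration and are not defined elsewhere in this note):

\[
\mathcal{H}_i = \{t : T_i < t \le T\},
\qquad
H_i = T - T_i.
\]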
This is the cleanest leakage-free baseline. It asks whether the treated unit’s own history is enough to forecast its post-treatment path.
Recommended defaults:
emission_dim = 1 for a single outcome,
state_dim in {2, 4, 8} selected by pre-treatment rolling-origin validation,
fit_em(..., num_iters=100) initially,
Gaussian observations on transformed rates, or counts after a variance-stabilizing transform.
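Rolling-origin validation for state_dim can be sketched as below. The helper names `rolling_origin_splits` and `select_state_dim`, the fold layout, and the `score_fn` callback are all assumptions for illustration; `score_fn` stands in for whatever routine fits a model of a given state dimension on the training slice and returns a forecast error on the validation slice.

```python
def rolling_origin_splits(t0, horizon, n_folds):
    """Return (train_end, val_end) pairs inside a pre-treatment window of length t0.

    Each fold trains on periods [0, train_end) and validates on
    [train_end, val_end); the last fold ends exactly at t0.
    """
    splits = []
    for k in range(n_folds):
        val_end = t0 - (n_folds - 1 - k) * horizon
        train_end = val_end - horizon
        if train_end <= 0:
            raise ValueError("pre-treatment window too short for this many folds")
        splits.append((train_end, val_end))
    return splits


def select_state_dim(y_pre, candidates, score_fn, horizon=4, n_folds=3):
    """Pick the candidate state_dim with the lowest mean validation score."""
    splits = rolling_origin_splits(len(y_pre), horizon, n_folds)

    def mean_score(dim):
        # score_fn(dim, train, val) is a hypothetical user-supplied callback.
        scores = [score_fn(dim, y_pre[:tr], y_pre[tr:va]) for tr, va in splits]
        return sum(scores) / len(scores)

    return min(candidates, key=mean_score)


splits = rolling_origin_splits(t0=20, horizon=4, n_folds=3)
# Dummy scorer for demonstration only: pretends state_dim = 4 forecasts best.
best = select_state_dim(list(range(20)), [2, 4, 8], lambda d, tr, va: abs(d - 4))
```

Because every fold lies strictly inside the pre-treatment window, the selection step never touches post-treatment data.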
Batched Unit-Level LGSSM
Dynamax's fit_em accepts batched emissions, so shared parameters can be fit over many unit sequences:
\[
\{Y_{i,1:T_{\text{train}}}: i \in \mathcal{I}_{\text{train}}\}
\longrightarrow
\widehat{\theta}.
\]
Then each treated unit is filtered on its own pre-treatment sequence and forecast forward with the shared \(\widehat{\theta}\).
This baseline uses cross-unit information without using the MTGP unit kernel. It is a useful comparator for asking whether a generic state-space forecast already captures most of the signal.
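Assembling the pooled training data can be sketched as follows, assuming batched emissions are laid out as (num_units, T_train, emission_dim). The helper `build_training_batch` and the dictionary-based panel are hypothetical. The set \(\mathcal{I}_{\text{train}}\) is passed in explicitly as `train_units`, and the caller is assumed to have excluded any unit that adopts before `t_train`.

```python
def build_training_batch(panel, train_units, t_train):
    """Stack the first t_train periods of each training unit.

    panel maps unit id -> list of per-period emission vectors (lists of floats).
    Returns a nested list of shape (num_units, t_train, emission_dim).
    (Hypothetical helper; convert the result to an array before fitting.)
    """
    batch = []
    for unit in train_units:
        series = panel[unit]
        if len(series) < t_train:
            raise ValueError(f"unit {unit!r} has fewer than {t_train} periods")
        # Truncate at the training cutoff so later periods never reach the fit.
        batch.append([list(obs) for obs in series[:t_train]])
    return batch


panel = {
    "a": [[1.0], [2.0], [3.0]],
    "b": [[4.0], [5.0], [6.0]],
    "c": [[7.0], [8.0], [9.0]],
}
batch = build_training_batch(panel, train_units=["a", "b"], t_train=2)
```

Each treated unit's own pre-treatment sequence is then filtered separately under the shared parameters, so units with different \(T_i\) do not need to be padded to a common length at forecast time.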
Multivariate Panel LGSSM
Set emission_dim to \(N_{\text{controls}} + K\) and treat each full panel slice as one multivariate emission vector. For simultaneous adoption:
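One possible data layout for this case is sketched below. It assumes that \(K\) counts treated units, that all of them adopt after the same number of pre-treatment periods `t0` (a 0-indexed count), and that `None` marks entries withheld from the conditioning set. The helper `stack_panel` is hypothetical, and how a filter should handle the withheld entries is left open here.

```python
def stack_panel(controls, treated, t0):
    """Build one emission vector per period: control outcomes, then treated ones.

    controls, treated: lists of equal-length per-unit series (lists of floats).
    Treated entries at period index >= t0 are replaced by None so that
    post-treatment outcomes never enter the conditioning set.
    """
    n_periods = len(controls[0])
    rows = []
    for t in range(n_periods):
        # Controls stay observed throughout the panel.
        row = [series[t] for series in controls]
        # Treated units are observed only before the shared adoption point.
        row += [series[t] if t < t0 else None for series in treated]
        rows.append(row)
    return rows


controls = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # N_controls = 2
treated = [[7.0, 8.0, 9.0]]                    # K = 1
rows = stack_panel(controls, treated, t0=2)
```

Each row has length \(N_{\text{controls}} + K\), matching emission_dim, and only the treated columns are blanked after adoption.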