Dynamax Forecasting Baselines for MTGPAX

Pure forecasting comparators for treatment-effect counterfactuals

mtgpax is the causal MTGP implementation. dynamax is a useful dependency for non-causal forecasting baselines. These are state-space models that are fit only to admissible pre-treatment or control observations and then forecast the treated post-treatment block as a counterfactual.

The local source checkout is in ref/dynamax/. The dependency is also installed from PyPI through pyproject.toml.

Dynamax API Surface

Dynamax is a JAX library for probabilistic state-space models. Its main families are:

  • hidden Markov models,
  • linear Gaussian state-space models,
  • nonlinear Gaussian state-space models,
  • generalized Gaussian state-space models with non-Gaussian emissions.

The baseline workhorse should be dynamax.linear_gaussian_ssm.LinearGaussianSSM. Its constructor is:

LinearGaussianSSM(
    state_dim,
    emission_dim,
    input_dim=0,
    has_dynamics_bias=True,
    has_emissions_bias=True,
)

The common workflow is:

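import jax
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

# emissions_pre: pre-treatment outcomes, shape (T0, emission_dim)
# H: number of post-treatment periods to forecast
key = jax.random.PRNGKey(0)

model = LinearGaussianSSM(state_dim=4, emission_dim=1)
params, props = model.initialize(key)
# EM sees only the pre-treatment block; lls is the per-iteration log probability
params, lls = model.fit_em(params, props, emissions_pre, num_iters=100)

# Filter through the pre-treatment block, then forecast H steps ahead
state_means, state_covs, y_means, y_covs = model.forecast(
    params,
    emissions=emissions_pre,
    num_forecast_timesteps=H,
)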

For emissions_pre with shape (T0, emission_dim), y_means has shape (H, emission_dim) and y_covs has shape (H, emission_dim, emission_dim). These are baseline forecasts of \(Y_{it}(0)\) and their predictive covariances for the held-out post-treatment periods.
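In principle, the forecast above is a Kalman filter run over the pre-treatment block, followed by open-loop propagation of the state moments. The minimal NumPy sketch below shows where y_means and y_covs come from. It assumes dynamics \(z_{t+1} = F z_t + b + \varepsilon_t\) with \(\varepsilon_t \sim \mathcal{N}(0, Q)\) and emissions \(y_t = C z_t + d + \delta_t\) with \(\delta_t \sim \mathcal{N}(0, R)\). The function name kalman_forecast and its argument layout are illustrative assumptions, not Dynamax's API or implementation.

```python
import numpy as np


def kalman_forecast(y, m0, S0, F, b, Q, C, d, R, horizon):
    """Kalman-filter y of shape (T0, n), then forecast `horizon` steps ahead.

    Model: z_1 ~ N(m0, S0), z_{t+1} = F z_t + b + N(0, Q), y_t = C z_t + d + N(0, R).
    Returns emission means (horizon, n) and covariances (horizon, n, n).
    """
    m, S = m0, S0  # prior mean and covariance of the current state
    for y_t in y:
        # Update: condition the current state on the observation y_t.
        S_y = C @ S @ C.T + R
        K = np.linalg.solve(S_y, C @ S).T  # Kalman gain S C^T S_y^{-1}
        m = m + K @ (y_t - (C @ m + d))
        S = S - K @ S_y @ K.T
        # Predict: push the state moments one step forward.
        m = F @ m + b
        S = F @ S @ F.T + Q
    # (m, S) is now the one-step-ahead prior for period T0 + 1.
    y_means, y_covs = [], []
    for _ in range(horizon):
        y_means.append(C @ m + d)
        y_covs.append(C @ S @ C.T + R)
        # No observations after T0: propagate without updating.
        m = F @ m + b
        S = F @ S @ F.T + Q
    return np.stack(y_means), np.stack(y_covs)


# Local-level model: a random walk observed with noise.
rng = np.random.default_rng(0)
y_pre = np.cumsum(rng.normal(size=(30, 1)), axis=0)
eye1, zero1 = np.eye(1), np.zeros(1)
means, covs = kalman_forecast(
    y_pre, m0=zero1, S0=10.0 * eye1, F=eye1, b=zero1, Q=0.1 * eye1,
    C=eye1, d=zero1, R=0.5 * eye1, horizon=5,
)
print(means.shape, covs.shape)  # (5, 1) (5, 1, 1)
```

For a random walk (\(F = 1\), \(b = 0\)), the forecast mean stays flat at the last filtered level. The predictive variance grows by \(Q\) per step, which is the behavior to expect from a local-level LGSSM over long horizons.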

Counterfactual Forecasting Setup

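Let \(\mathcal{T}\) be the set of treated units and \(T_i\) the adoption time of unit \(i\), indexed so that periods \(t \le T_i\) are pre-treatment and periods \(t > T_i\) are post-treatment. Let \(T\) be the final period and \(T_{\text{train}}\) the last control-unit period admitted for fitting. A pure forecasting baseline must not condition on treated post-treatment observations, so its fitting set is: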

\[ \mathcal{C}_{\text{forecast}} = \{(i,t): i \in \mathcal{T}, t \le T_i\} \cup \{(i,t): i \notin \mathcal{T}, t \le T_{\text{train}}\}. \]

For a treated unit \(i\), the forecast horizon is:

\[ H_i = T - T_i. \]

The baseline counterfactual is:

\[ \widehat{Y}_{it}^{\text{Dynamax}}(0) = \mathbb{E}_{\widehat{\theta}} \left[ Y_{it} \mid Y_{i,1:T_i} \right], \qquad t > T_i. \]

The corresponding effect estimate is:

\[ \widehat{\tau}_{it}^{\text{Dynamax}} = Y_{it} - \widehat{Y}_{it}^{\text{Dynamax}}(0). \]

This baseline is intentionally weaker than MTGPAX as a causal tool: it is a forecasting comparator, not a design-based identification strategy.
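The minimal NumPy sketch below builds \(\mathcal{C}_{\text{forecast}}\) and the set of cells to be forecast. It assumes a panel with units as rows and periods \(1, \ldots, T\) as columns, so period \(t\) is column \(t - 1\). The function name forecast_masks and the {unit: T_i} dictionary are illustrative choices, not part of mtgpax. The row sums of the forecast mask recover \(H_i = T - T_i\).

```python
import numpy as np


def forecast_masks(num_units, num_periods, adoption, t_train):
    """Return (fit_mask, forecast_mask), boolean arrays of shape (N, T).

    fit_mask marks C_forecast; forecast_mask marks treated cells with t > T_i.
    Periods are numbered 1..T, so period t is column t - 1 and the slice
    [:T_i] covers exactly t <= T_i.
    adoption: {treated unit index: T_i}; t_train: control cutoff T_train.
    """
    fit_mask = np.zeros((num_units, num_periods), dtype=bool)
    forecast_mask = np.zeros_like(fit_mask)
    for i in range(num_units):
        if i in adoption:
            fit_mask[i, : adoption[i]] = True  # treated: t <= T_i
            forecast_mask[i, adoption[i]:] = True  # treated: t > T_i
        else:
            fit_mask[i, :t_train] = True  # control: t <= T_train
    return fit_mask, forecast_mask


fit_mask, forecast_mask = forecast_masks(4, 12, adoption={0: 8, 3: 10}, t_train=12)
assert not (fit_mask & forecast_mask).any()  # no cell is both fitted and forecast
horizons = forecast_mask.sum(axis=1)  # H_i = T - T_i for treated rows, 0 for controls
print(horizons)  # [4 0 0 2]
```

Checking that the two masks are disjoint is a cheap leakage test to run before any fit.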

Baseline Variants

Per-Treated-Unit LGSSM

Fit one LGSSM per treated unit using only that unit’s pre-treatment series:

\[ Y_{i,1:T_i} \longrightarrow \widehat{Y}_{i,T_i+1:T}(0). \]

This is the cleanest leakage-free baseline. It asks whether the treated unit’s own history is enough to forecast its post-treatment path.
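The per-treated-unit variant can be sketched as a loop that passes each treated unit's pre-treatment slice to a forecaster and writes the result into a panel. \(\widehat{\tau}_{it}\) then follows by subtraction. This assumes an (N, T) panel with 1-based periods mapped to 0-based columns. fit_forecast is a hypothetical callable that, in practice, would wrap the LinearGaussianSSM initialize/fit_em/forecast workflow. The last-value forecaster in the usage is a toy stand-in, not a recommended baseline.

```python
import numpy as np


def per_unit_counterfactuals(y, adoption, fit_forecast):
    """Forecast each treated unit from its own pre-treatment series only.

    y: (N, T) panel; adoption: {treated unit: T_i}, periods numbered 1..T.
    fit_forecast(y_pre, horizon) -> array of length `horizon`. It only ever
    receives y[i, :T_i], so post-treatment observations cannot enter the fit.
    Returns (counterfactual, effect) panels, NaN outside treated post-periods.
    """
    N, T = y.shape
    counterfactual = np.full((N, T), np.nan)
    for i, T_i in adoption.items():
        counterfactual[i, T_i:] = fit_forecast(y[i, :T_i], T - T_i)
    effect = y - counterfactual  # NaN wherever no counterfactual was produced
    return counterfactual, effect


def last_value(y_pre, horizon):
    # Toy forecaster: repeat the last pre-treatment observation.
    return np.full(horizon, y_pre[-1])


y = np.array([
    [1.0, 2.0, 3.0, 4.0, 10.0, 10.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [5.0, 5.0, 5.0, 6.0, 6.0, 6.0],
])
counterfactual, effect = per_unit_counterfactuals(y, {0: 4, 2: 3}, last_value)
```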

Recommended defaults:

  • emission_dim = 1 for a single outcome,
  • state_dim in {2, 4, 8} selected by pre-treatment rolling-origin validation,
  • fit_em(..., num_iters=100) initially,
  • Gaussian observations on transformed rates, or counts after a variance-stabilizing transform.
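Rolling-origin validation for state_dim can be sketched as follows. fit_forecast(y_train, horizon, candidate) is a hypothetical callable that fits a model with the candidate setting (for an LGSSM, state_dim) on y_train and returns a (horizon, emission_dim) forecast. Every fold trains on a prefix of the pre-treatment series and scores the next horizon points, so no post-treatment data is touched. The mean-of-last-k forecaster in the usage is only a stand-in that makes the sketch runnable.

```python
import numpy as np


def rolling_origin_folds(T0, min_train, horizon, step=1):
    """(origin, horizon) pairs: train on y[:origin], score on y[origin:origin + horizon]."""
    return [(o, horizon) for o in range(min_train, T0 - horizon + 1, step)]


def select_by_rolling_origin(y_pre, candidates, fit_forecast, min_train, horizon, step=1):
    """Return (best candidate, {candidate: mean squared forecast error})."""
    scores = {}
    for c in candidates:
        errors = []
        for origin, h in rolling_origin_folds(len(y_pre), min_train, horizon, step):
            y_hat = fit_forecast(y_pre[:origin], h, c)
            errors.append(np.mean((y_pre[origin:origin + h] - y_hat) ** 2))
        scores[c] = float(np.mean(errors))
    best = min(scores, key=scores.get)
    return best, scores


def mean_of_last_k(y_train, horizon, k):
    # Toy stand-in for an LGSSM fit: repeat the mean of the last k observations.
    return np.repeat(y_train[-k:].mean(axis=0, keepdims=True), horizon, axis=0)


y_pre = np.arange(20.0)[:, None]  # linear trend, shape (T0, 1)
best, scores = select_by_rolling_origin(
    y_pre, [1, 4, 8], mean_of_last_k, min_train=8, horizon=3, step=3
)
```

On a linear trend, the shortest window lags least, so k = 1 wins. With real data, the same loop would compare state_dim in {2, 4, 8}.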

Batched Unit-Level LGSSM

Dynamax fit_em accepts batched emissions of shape (num_sequences, T_train, emission_dim), so we can fit shared parameters over many unit sequences:

\[ \{Y_{i,1:T_{\text{train}}}: i \in \mathcal{I}_{\text{train}}\} \longrightarrow \widehat{\theta}. \]

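Here \(\mathcal{I}_{\text{train}}\) must keep every training cell inside \(\mathcal{C}_{\text{forecast}}\). That means control units, plus treated units only if \(T_{\text{train}} \le T_i\). Each treated unit is then filtered on its own pre-treatment sequence and forecast forward with the shared \(\widehat{\theta}\).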

This baseline uses cross-unit information without using the MTGP unit kernel. It is a useful comparator for asking whether a generic state-space forecast already captures most of the signal.
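The data layout for the batched variant can be sketched as follows. The sketch assumes a panel y of shape (N, T, emission_dim), and that the batched emissions passed to fit_em share one length, here \(T_{\text{train}}\). The admissibility rule follows \(\mathcal{C}_{\text{forecast}}\): a unit contributes a length-\(T_{\text{train}}\) sequence only if all of its first \(T_{\text{train}}\) periods are admissible. The function names are hypothetical.

```python
import numpy as np


def batched_training_emissions(y, adoption, t_train):
    """Stack admissible length-t_train sequences for a batched fit_em call.

    y: panel of shape (N, T, emission_dim); adoption: {treated unit: T_i}.
    A unit qualifies if periods 1..t_train all lie in C_forecast: every
    control unit, plus treated units with T_i >= t_train.
    Returns an array of shape (B, t_train, emission_dim) and the unit indices.
    """
    units = [i for i in range(y.shape[0]) if adoption.get(i, t_train) >= t_train]
    return y[units, :t_train], units


def treated_forecast_inputs(y, adoption):
    """Map each treated unit to (its pre-treatment block y[i, :T_i], horizon T - T_i)."""
    T = y.shape[1]
    return {i: (y[i, :T_i], T - T_i) for i, T_i in adoption.items()}


y = np.random.default_rng(1).normal(size=(5, 12, 1))
adoption = {0: 8, 3: 10}
batch, units = batched_training_emissions(y, adoption, t_train=9)
print(batch.shape, units)  # (4, 9, 1) [1, 2, 3, 4]
inputs = treated_forecast_inputs(y, adoption)
```

Treated units have different \(T_i\), so their pre-treatment blocks have different lengths. Filtering and forecasting with the shared \(\widehat{\theta}\) is therefore simplest one unit at a time.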

Multivariate Panel LGSSM

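Set emission_dim = \(N_{\text{controls}} + K\), where \(K = |\mathcal{T}|\) is the number of treated units. Treat each time slice of the panel as one multivariate emission vector over all \(N = N_{\text{controls}} + K\) units. For simultaneous adoption after period \(T_0\):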

\[ \mathbf{y}_t = \left(Y_{1t}, \ldots, Y_{Nt}\right)^\top, \qquad t \le T_0. \]

After fitting on pre-treatment panel vectors, forecast the whole vector over \(t > T_0\) and retain the treated coordinates.
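Dynamax emissions are time-major, with shape (num_timesteps, emission_dim). The panel must therefore be transposed before fitting, and the treated coordinates picked out of the forecast afterwards. The minimal sketch below assumes an (N, T) panel with common adoption after period \(T_0\). The helper names are hypothetical.

```python
import numpy as np


def panel_emissions(y, T0):
    """(N, T) panel -> time-major pre-treatment emissions of shape (T0, N).

    Row t - 1 of the result is the vector y_t = (Y_1t, ..., Y_Nt).
    """
    return y[:, :T0].T


def treated_coordinates(y_forecast, treated_units):
    """(H, N) panel forecast -> (K, H) forecasts for the treated units only."""
    return y_forecast[:, treated_units].T


y = np.arange(18.0).reshape(3, 6)  # N = 3 units, T = 6 periods
emissions_pre = panel_emissions(y, T0=4)  # shape (4, 3)
# Stand-in for an (H, N) forecast; here simply the observed post-period block.
y_forecast = y[:, 4:].T
treated_cf = treated_coordinates(y_forecast, [0, 2])  # shape (2, 2)
```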

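This variant is closer to a forecasting analogue of synthetic control, but it can be fragile when \(N\) is large relative to \(T_0\). With a full emission covariance, the emission parameters grow on the order of \(N^2\), while the pre-treatment block supplies only \(N T_0\) observations.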
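A quick parameter count makes this fragility concrete. The sketch below counts the free entries of an LGSSM with dynamics and emission biases and full symmetric covariances (initial, dynamics, and emission). Treating all of these as estimated is an assumption; fixing or diagonalizing any of them would lower the count.

```python
def sym_entries(k):
    """Free entries of a k x k symmetric matrix."""
    return k * (k + 1) // 2


def lgssm_num_params(state_dim, emission_dim):
    """Free parameters of an LGSSM with biases and full covariances.

    Assumes initial mean/covariance, dynamics weights/bias/covariance, and
    emission weights/bias/covariance are all estimated.
    """
    D, N = state_dim, emission_dim
    initial = D + sym_entries(D)
    dynamics = D * D + D + sym_entries(D)
    emissions = N * D + N + sym_entries(N)
    return initial + dynamics + emissions


N, T0 = 50, 20
print(lgssm_num_params(state_dim=4, emission_dim=N), N * T0)  # 1569 1000
```

With \(N = 50\) units and a 4-dimensional state, the count is 1569 free parameters against \(N T_0 = 1000\) pre-treatment observations when \(T_0 = 20\). The emission covariance alone accounts for 1275 of them.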

Count and Regime Baselines

For count outcomes, useful follow-ups are:

  • PoissonHMM for regime-switching count levels,
  • generalized Gaussian SSMs for non-Gaussian emission models,
  • transforms such as \(\log(1 + Y_{it})\) when a Gaussian LGSSM is good enough as a baseline.

These should be labeled as forecasting baselines. They should not replace the MTGPAX causal estimand.
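When a Gaussian LGSSM is fit on \(\log(1 + Y_{it})\), the forecast must be mapped back to counts before computing effects. Under a Gaussian forecast on the transformed scale, \(1 + Y\) is log-normal. As a result, \(\mathrm{expm1}\) of the forecast mean gives the median of \(Y\), not its mean. The minimal NumPy sketch below computes both back-transforms. The function names are hypothetical.

```python
import numpy as np


def to_log1p(counts):
    """Variance-stabilizing transform applied before a Gaussian LGSSM fit."""
    return np.log1p(counts)


def back_transform(z_mean, z_var):
    """Map a Gaussian forecast N(z_mean, z_var) of log(1 + Y) back to counts.

    1 + Y is then log-normal, so E[Y] = exp(z_mean + z_var / 2) - 1,
    while expm1(z_mean) is the median of Y.
    """
    mean = np.exp(z_mean + 0.5 * z_var) - 1.0
    median = np.expm1(z_mean)
    return mean, median


z_mean = to_log1p(9.0)  # forecast centered on log(10)
mean, median = back_transform(z_mean, z_var=0.2)
```

Which back-transform enters \(\widehat{\tau}_{it}\) is a modeling choice. The gap between them grows with the forecast variance, so it matters most at long horizons.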

Implementation Contract

The baseline interface should mirror the causal estimator:

baseline = DynamaxForecastBaseline(
    model="lgssm",
    state_dim=4,
    variant="per_treated_unit",
)

fit = baseline.fit(
    y=outcomes,
    treated_units=[0, 3, 8],
    treatment_time=10,
    control_mask=control_mask,
)

effects = fit.predict_effects()

Required returned quantities:

  • counterfactual_mean: array over missing_mask (the treated post-treatment cells \((i,t)\) with \(i \in \mathcal{T}\), \(t > T_i\)),
  • counterfactual_cov or marginal standard errors where available,
  • effect_mean = y_observed_post - counterfactual_mean,
  • rolling-origin validation metrics from pre-treatment periods,
  • metadata recording which observations were used in fitting.
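One way to hold the required quantities is a small result container. It also records the fitting mask and refuses to be built if that mask overlaps the forecast cells. ForecastBaselineResult and its field names are hypothetical, chosen to mirror the list above; they are not an existing mtgpax class. This sketch stores marginal standard deviations rather than full covariances.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ForecastBaselineResult:
    """Hypothetical container for a forecasting-baseline fit on an (N, T) panel."""

    y_observed: np.ndarray
    missing_mask: np.ndarray  # treated post-treatment cells to forecast
    fit_mask: np.ndarray  # metadata: cells actually used for fitting
    counterfactual_mean: np.ndarray  # NaN outside missing_mask
    counterfactual_sd: np.ndarray  # marginal standard errors, NaN outside missing_mask
    validation_metrics: dict = field(default_factory=dict)  # rolling-origin scores

    def __post_init__(self):
        # Leakage guard: no cell may be both fitted on and forecast.
        if np.any(self.fit_mask & self.missing_mask):
            raise ValueError("fit_mask overlaps missing_mask: post-treatment leakage")

    def effect_mean(self):
        """effect = y_observed - counterfactual_mean on missing_mask, NaN elsewhere."""
        return np.where(
            self.missing_mask, self.y_observed - self.counterfactual_mean, np.nan
        )


# Toy panel: unit 0 is treated from column 2 on; unit 1 is a control.
y_obs = np.array([[1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0]])
missing = np.array([[False, False, True, True], [False, False, False, False]])
cf_mean = np.where(missing, 2.0, np.nan)
cf_sd = np.where(missing, 0.5, np.nan)
result = ForecastBaselineResult(
    y_observed=y_obs,
    missing_mask=missing,
    fit_mask=~missing,
    counterfactual_mean=cf_mean,
    counterfactual_sd=cf_sd,
)
effect = result.effect_mean()
```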

Minimal Smoke Test

import jax
import jax.numpy as jnp
from dynamax.linear_gaussian_ssm import LinearGaussianSSM

key = jax.random.PRNGKey(0)
T0 = 30  # pre-treatment length
H = 5    # forecast horizon

# Noisy sine wave; only the first T0 points are used for fitting
t = jnp.arange(T0 + H)
y = jnp.sin(t / 5.0)[:, None] + 0.05 * jax.random.normal(key, (T0 + H, 1))
y_pre = y[:T0]

model = LinearGaussianSSM(state_dim=2, emission_dim=1)
params, props = model.initialize(key)
params, lls = model.fit_em(params, props, y_pre, num_iters=5, verbose=False)

# Forecast H steps past the end of y_pre
_, _, y_forecast_mean, y_forecast_cov = model.forecast(
    params,
    emissions=y_pre,
    num_forecast_timesteps=H,
)

print("forecast_mean_shape:", y_forecast_mean.shape)
print("forecast_cov_shape:", y_forecast_cov.shape)
print("last_em_log_prob:", float(lls[-1]))

Output:

forecast_mean_shape: (5, 1)
forecast_cov_shape: (5, 1, 1)
last_em_log_prob: 9.30750846862793