Replicating DJ Paper Code Fragments

CmdStanPy vs Stan C++ vs direct JAX/BlackJAX

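This note replicates the central code fragments from Bob Carpenter’s DJ paper, especially the comparison between:

  • fitting the model through CmdStanPy,
  • the C++ class that Stan transpiles the model into, and
  • a direct JAX/BlackJAX implementation.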

The original page is archived locally at ref/dj-paper/dj-paper.html and ref/dj-paper/dj-paper.md.

Takeaway for mtgpax

For mtgpax, the direct JAX path is attractive for full Bayes because it gives us:

  • differentiable tensor code,
  • PyTree parameter containers rather than serialized parameter vectors,
  • explicit transforms and Jacobian adjustments,
  • BlackJAX/NumPyro-style HMC/NUTS on JAX arrays,
  • an easier path to GPU/TPU acceleration than Stan.

Stan still wins on mature diagnostics, robust default NUTS behavior, and not requiring the user to hand-write transforms. The DJ paper’s point is that the gap is not philosophical: Stan compiles a constrained model to an unconstrained differentiable log density, and we can write the same object directly in JAX if we are willing to write the constraining transforms and their Jacobian adjustments ourselves.

For mtgpax, this suggests a practical split:

default fast path:  MAP / VI in JAX
full Bayes path:    JAX + BlackJAX or NumPyro
reference path:     Stan/CmdStanPy for smaller validation models

Linear Regression Model

The example model is a Bayesian linear regression:

\[ y_n \sim \mathcal{N}(\alpha + x_n^\top\beta,\sigma), \]

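with priors (as in Stan, the normal is parameterized by its standard deviation and the exponential by its rate, so the prior mean of \(\sigma\) is \(1/0.5 = 2\)):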

\[ \alpha \sim \mathcal{N}(0,5), \qquad \beta_p \sim \mathcal{N}(0,2.5), \qquad \sigma \sim \operatorname{Exponential}(0.5). \]
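Collecting the model into one log density shows exactly what Stan’s target accumulates and what the JAX log_posterior below has to compute. Dropping every additive term that does not involve a parameter,

\[ \log p(\alpha,\beta,\sigma \mid y) = -\frac{\alpha^2}{2\cdot 5^2} - \sum_{p=1}^{P}\frac{\beta_p^2}{2\cdot 2.5^2} - 0.5\,\sigma - N\log\sigma - \frac{1}{2\sigma^2}\sum_{n=1}^{N}\left(y_n - \alpha - x_n^\top\beta\right)^2 + \text{const}, \qquad \sigma > 0. \]

The \(-N\log\sigma\) term stays because \(\sigma\) is a parameter; only terms such as \(-\tfrac{N}{2}\log 2\pi\) and \(-\log 5\) are absorbed into the constant.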

Stan Program

data {
  int<lower=0> N, N_new, P;
  matrix[N, P] x;
  vector[N] y;
  matrix[N_new, P] x_new;
}
parameters {
  real alpha;
  vector[P] beta;
  real<lower=0> sigma;
}
model {
  alpha ~ normal(0, 5);
  beta ~ normal(0, 2.5);
  sigma ~ exponential(0.5);
  y ~ normal(alpha + x * beta, sigma);
}
generated quantities {
  vector[N_new] y_new = to_vector(normal_rng(alpha + x_new * beta, sigma));
}

Stan distribution statements (the ~ statements) increment the target log density and drop additive terms that do not depend on parameters. The model block is equivalent to:

target += normal_lupdf(alpha | 0, 5);
target += normal_lupdf(beta | 0, 2.5);
target += exponential_lupdf(sigma | 0.5);
target += normal_lupdf(y | alpha + x * beta, sigma);
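The _lupdf suffix means Stan may drop additive terms that do not depend on parameters. A minimal NumPy sketch (the helper names normal_lpdf and normal_lupdf_param_scale are ours, not Stan functions) shows what that means for the likelihood, where sigma is a parameter: only the -0.5 log(2π) term per observation can go, the -log(sigma) term must stay, and the dropped amount is identical for every parameter value, so the posterior is unchanged.

```python
import numpy as np


def normal_lpdf(y, mu, sigma):
    # Full normal log density, summed over observations.
    z = (y - mu) / sigma
    return np.sum(-0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * z**2)


def normal_lupdf_param_scale(y, mu, sigma):
    # Hypothetical dropped-constant version for the case where sigma is a parameter:
    # -log(sigma) depends on a parameter, so it is kept; -0.5*log(2*pi) is dropped.
    z = (y - mu) / sigma
    return np.sum(-np.log(sigma) - 0.5 * z**2)


rng = np.random.default_rng(0)
y = rng.normal(size=10)

# The gap between the two versions is the same for every (mu, sigma).
gaps = [
    normal_lpdf(y, mu, sigma) - normal_lupdf_param_scale(y, mu, sigma)
    for mu, sigma in [(0.0, 1.0), (1.5, 0.3), (-2.0, 4.0)]
]
print(np.allclose(gaps, -0.5 * len(y) * np.log(2 * np.pi)))  # True
```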

Simulate Data

import functools
import json
from pathlib import Path

import numpy as np


def simulate_regression(n=128, p=2, n_new=4, seed=145777):
    def simulate_covariates(n):
        x = rng.normal(size=(n, p))
        x[:, 1] = x[:, 1] ** 2  # second covariate enters as a squared term
        return x

    rng = np.random.default_rng(seed)
    alpha = rng.normal(0.0, 5.0)
    beta = rng.normal(0.0, 2.5, size=p)
    sigma = rng.exponential(1.0 / 0.5)  # NumPy's exponential takes scale = 1 / rate
    x = simulate_covariates(n)
    mu = alpha + x @ beta
    y = rng.normal(mu, sigma)
    x_new = simulate_covariates(n_new)

    parameters = {"alpha": alpha, "beta": beta, "sigma": sigma}
    data = {
        "N": n,
        "P": p,
        "N_new": n_new,
        "x": x.tolist(),
        "y": y.tolist(),
        "x_new": x_new.tolist(),
    }
    return parameters, data


params, data = simulate_regression()
print(
    f"alpha:{params['alpha']:7.3f}; "
    f"beta[0]:{params['beta'][0]:7.3f}; "
    f"beta[1]:{params['beta'][1]:7.3f}; "
    f"sigma:{params['sigma']:7.3f}"
)
alpha: -9.147; beta[0]: -4.819; beta[1]:  1.146; sigma:  0.595
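Because the priors are weak relative to 128 observations, the posterior means should sit close to the least-squares estimates, and both should recover the simulated values up to sampling noise. A self-contained NumPy sketch of that sanity check; it re-simulates data from the printed parameter values with its own arbitrary seed, so its numbers will differ slightly from the data above.

```python
import numpy as np

# Parameter values copied from the printed simulation output; the seed is arbitrary.
alpha, beta, sigma = -9.147, np.array([-4.819, 1.146]), 0.595
rng = np.random.default_rng(0)
n = 128
x = rng.normal(size=(n, 2))
x[:, 1] = x[:, 1] ** 2  # same squared second covariate as simulate_regression
y = rng.normal(alpha + x @ beta, sigma)

# Ordinary least squares with an intercept column prepended.
design = np.column_stack([np.ones(n), x])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
resid = y - design @ coef
sigma_hat = np.sqrt(resid @ resid / (n - design.shape[1]))
print(coef, sigma_hat)
```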

CmdStanPy Fragment

The CmdStanPy code in the paper is compact:

import cmdstanpy as csp

m = csp.CmdStanModel(stan_file="linear-regression.stan")
fit = m.sample(data=data, show_progress=False)
summary_csp = fit.summary(sig_figs=3)

The version below checks that CmdStan is installed, compiles the Stan program, runs NUTS through CmdStanPy, and prints a compact posterior summary. It runs 4 chains with 500 warmup and 500 sampling iterations each (CmdStan’s defaults are 1,000 of each) so the page renders quickly, but the compilation and sampling are real.

import cmdstanpy as csp

stan_code = r"""
data {
  int<lower=0> N, N_new, P;
  matrix[N, P] x;
  vector[N] y;
  matrix[N_new, P] x_new;
}
parameters {
  real alpha;
  vector[P] beta;
  real<lower=0> sigma;
}
model {
  alpha ~ normal(0, 5);
  beta ~ normal(0, 2.5);
  sigma ~ exponential(0.5);
  y ~ normal(alpha + x * beta, sigma);
}
generated quantities {
  vector[N_new] y_new = to_vector(normal_rng(alpha + x_new * beta, sigma));
}
"""

work = Path("dj_replication_files")
work.mkdir(exist_ok=True)
(work / "linear-regression.stan").write_text(stan_code)
(work / "linear-regression-data.json").write_text(json.dumps(data))

try:
    cmdstan_path = csp.cmdstan_path()
    print(f"CmdStan available: {cmdstan_path}")
except ValueError as e:
    raise RuntimeError("CmdStan is not installed; run cmdstanpy.install_cmdstan().") from e

m = csp.CmdStanModel(stan_file=str(work / "linear-regression.stan"))
fit = m.sample(
    data=data,
    chains=4,
    parallel_chains=4,
    iter_warmup=500,
    iter_sampling=500,
    show_progress=False,
    seed=441_582,
)

summary_csp = fit.summary(sig_figs=3)
print(summary_csp.loc[["alpha", "beta[1]", "beta[2]", "sigma"]])

stan_draws_xr = fit.draws_xr()
CmdStan available: /Users/alal/.cmdstan/cmdstan-2.34.1
         Mean      MCSE  StdDev    5%   50%   95%    N_Eff   N_Eff/s  R_hat
alpha   -9.17  0.001340  0.0576 -9.27 -9.17 -9.08  1840.00  33400.00  0.999
beta[1] -4.81  0.001120  0.0485 -4.89 -4.81 -4.73  1875.00  34082.00  0.999
beta[2]  1.15  0.000728  0.0316  1.10  1.15  1.20  1887.00  34305.00  1.000
sigma    0.54  0.000000  0.0400  0.49  0.54  0.61  1406.16  25566.47  1.000

What Stan’s Transpiled C++ Gives You

The paper’s section 5 emphasizes that Stan compiles the model into a C++ class with an unconstrained log-density method. The important pieces are:

template <bool propto, bool jacobian, typename T>
T log_prob_impl(...) const {
  // read unconstrained parameters
  // transform constrained parameters, e.g. sigma = exp(sigma_u)
  // add Jacobian term if jacobian == true
  // accumulate target += prior and likelihood terms
  // return scalar log density for autodiff
}
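The same structure can be mimicked in a few lines of NumPy. This is a hypothetical sketch, not stanc output: the function name log_prob, the flat layout [alpha, beta..., sigma_u], and the keyword flags are ours, chosen to mirror the propto and jacobian template parameters above. Note the serialized parameter vector, in contrast to the PyTree dictionaries used in the JAX code later.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def log_prob(theta_u, x, y, propto=True, jacobian=True):
    """Regression log density on the unconstrained scale (hypothetical mirror of log_prob_impl).

    theta_u is a flat vector [alpha, beta_1, ..., beta_P, sigma_u] with sigma = exp(sigma_u).
    """
    theta_u = np.asarray(theta_u, dtype=float)
    alpha, beta, sigma_u = theta_u[0], theta_u[1:-1], theta_u[-1]
    sigma = np.exp(sigma_u)  # constraining transform for real<lower=0>
    lp = 0.0
    if jacobian:
        lp += sigma_u  # log |d exp(u) / du| = u
    lp += -0.5 * (alpha / 5.0) ** 2  # alpha ~ normal(0, 5), constants dropped
    lp += -0.5 * np.sum((beta / 2.5) ** 2)  # beta ~ normal(0, 2.5), constants dropped
    lp += -0.5 * sigma  # sigma ~ exponential(0.5), constant log(0.5) dropped
    resid = y - (alpha + x @ beta)
    # Likelihood: -log(sigma) = -sigma_u depends on a parameter, so it is always kept.
    lp += -len(y) * sigma_u - 0.5 * np.sum(resid**2) / sigma**2
    if not propto:
        # Add back every parameter-free constant dropped above.
        lp += -0.5 * LOG_2PI - np.log(5.0)
        lp += beta.size * (-0.5 * LOG_2PI - np.log(2.5))
        lp += np.log(0.5)
        lp += -0.5 * len(y) * LOG_2PI
    return lp


rng = np.random.default_rng(1)
x = rng.normal(size=(20, 2))
y = rng.normal(size=20)
theta = np.array([0.1, -0.3, 0.2, np.log(0.7)])
# Switching the Jacobian on adds exactly sigma_u = log(0.7).
print(log_prob(theta, x, y) - log_prob(theta, x, y, jacobian=False))
```

Switching propto off changes the value by the same amount for every parameter vector, which is why samplers can safely use the cheaper dropped-constant form.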

For this regression, the transform is:

\[ (\alpha, \beta, \sigma^{u}) \mapsto (\alpha, \beta, \exp(\sigma^{u})). \]

The log-Jacobian adjustment is:

\[ \log\left|\frac{d\exp(\sigma^u)}{d\sigma^u}\right| = \sigma^u. \]
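Why the extra \(\sigma^u\) term matters: a density on \(\sigma\) must be multiplied by \(|d\sigma/d\sigma^u| = \exp(\sigma^u)\) to become a density on \(\sigma^u\). A small NumPy check, using only the Exponential(0.5) prior on \(\sigma\) (the helper names are ours): with the Jacobian term the transformed density integrates numerically to one over \(\sigma^u\); without it the integrand tends to 0.5 as \(\sigma^u \to -\infty\), so it is not even normalizable.

```python
import numpy as np


def exponential_logpdf(s, rate=0.5):
    # Log density of Exponential(rate) for s > 0.
    return np.log(rate) - rate * s


def trapezoid(f, u):
    # Plain trapezoid rule, avoiding NumPy-version differences in np.trapz / np.trapezoid.
    return np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(u))


u = np.linspace(-30.0, 5.0, 200_001)  # grid on the unconstrained scale sigma_u
with_jacobian = np.exp(exponential_logpdf(np.exp(u)) + u)
without_jacobian = np.exp(exponential_logpdf(np.exp(u)))

print(trapezoid(with_jacobian, u))  # close to 1
print(trapezoid(without_jacobian, u))  # roughly 15, and grows as the lower limit decreases
```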

That is exactly the manual transform we write below in JAX.

Direct JAX / BlackJAX Implementation

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jrd
from jax.scipy import stats
import blackjax

x = jnp.array(data["x"])
y = jnp.array(data["y"])
x_new = jnp.array(data["x_new"])


def log_posterior(params):
    lp = 0.0
    lp += jnp.sum(stats.norm.logpdf(params["alpha"], loc=0.0, scale=5.0))
    lp += jnp.sum(stats.norm.logpdf(params["beta"], loc=0.0, scale=2.5))
    lp += jnp.sum(stats.expon.logpdf(params["sigma"], scale=1.0 / 0.5))
    mu = params["alpha"] + x @ params["beta"]
    lp += jnp.sum(stats.norm.logpdf(y, loc=mu, scale=params["sigma"]))
    return lp

Transform and Jacobian

def transform(params):
    # constrained -> unconstrained: sigma is stored as log(sigma)
    return {
        "alpha": params["alpha"],
        "beta": jnp.array(params["beta"]),
        "sigma": jnp.log(params["sigma"]),
    }


def inv_transform(t_params):
    log_adjust = 0.0
    # sigma = exp(sigma_u) maps the real line to (0, inf); log |d sigma / d sigma_u| = sigma_u.
    sigma = jnp.exp(t_params["sigma"])
    log_adjust += t_params["sigma"]
    params = {
        "alpha": t_params["alpha"],
        "beta": jnp.array(t_params["beta"]),
        "sigma": sigma,
    }
    return params, log_adjust


def log_posterior_transformed(t_params):
    # Map back to the constrained scale and add the log-Jacobian of that map.
    params, log_adjust = inv_transform(t_params)
    return log_posterior(params) + log_adjust
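Hand-written transforms are easy to get subtly wrong. A NumPy sketch of two cheap checks; transform_np and inv_transform_np are hypothetical stand-ins that mirror the JAX dictionaries above. inv_transform should undo transform, and the returned log_adjust should equal the log of a finite-difference estimate of \(d\sigma/d\sigma^u\).

```python
import numpy as np


def transform_np(params):
    # constrained -> unconstrained: only sigma changes
    return {"alpha": params["alpha"], "beta": np.asarray(params["beta"]), "sigma": np.log(params["sigma"])}


def inv_transform_np(t_params):
    # unconstrained -> constrained, plus log |d sigma / d sigma_u|
    sigma = np.exp(t_params["sigma"])
    params = {"alpha": t_params["alpha"], "beta": np.asarray(t_params["beta"]), "sigma": sigma}
    return params, t_params["sigma"]


params = {"alpha": -9.17, "beta": np.array([-4.8, 1.15]), "sigma": 0.54}
t_params = transform_np(params)
round_trip, log_adjust = inv_transform_np(t_params)

# Central finite difference of sigma = exp(sigma_u) with respect to sigma_u.
h = 1e-6
u = t_params["sigma"]
dsigma_du = (np.exp(u + h) - np.exp(u - h)) / (2 * h)

print(np.isclose(round_trip["sigma"], params["sigma"]))  # True
print(np.isclose(log_adjust, np.log(dsigma_du)))  # True
```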

Random Initialization

def random_init_transformed(key):
    key0, key1, key2 = jrd.split(key, 3)
    return {
        "alpha": jrd.normal(key0),
        "beta": jrd.normal(key1, shape=(x.shape[1],)),  # one coefficient per covariate
        "sigma": jrd.normal(key2),  # unconstrained scale: this is log(sigma)
    }


seed = 441_582
key = jrd.key(seed)
init_key, nuts_key, pred_key = jrd.split(key, 3)
t_params_init = random_init_transformed(init_key)
params_init, log_adjust = inv_transform(t_params_init)

print(f"{t_params_init=}")
print(f"{params_init=}")
print(f"{log_adjust=}")
t_params_init={'alpha': Array(0.11128246, dtype=float64), 'beta': Array([1.10485799, 2.0610555 ], dtype=float64), 'sigma': Array(-1.51958322, dtype=float64)}
params_init={'alpha': Array(0.11128246, dtype=float64), 'beta': Array([1.10485799, 2.0610555 ], dtype=float64), 'sigma': Array(0.21880306, dtype=float64)}
log_adjust=Array(-1.51958322, dtype=float64)
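The standard-normal initialization above is a choice of this note. Stan’s default is to draw each unconstrained parameter uniformly from (-2, 2), which for \(\sigma\) means an initial value between \(e^{-2} \approx 0.14\) and \(e^{2} \approx 7.4\). A NumPy sketch of that convention; the function name stan_style_init and its num_beta argument are hypothetical. A JAX version would use jax.random.uniform with minval=-2.0 and maxval=2.0.

```python
import numpy as np


def stan_style_init(rng, num_beta, radius=2.0):
    # Uniform(-radius, radius) on every unconstrained coordinate, as in Stan's default inits.
    return {
        "alpha": rng.uniform(-radius, radius),
        "beta": rng.uniform(-radius, radius, size=num_beta),
        "sigma": rng.uniform(-radius, radius),  # this is sigma_u = log(sigma)
    }


rng = np.random.default_rng(441_582)
t_init = stan_style_init(rng, num_beta=2)
sigma_init = np.exp(t_init["sigma"])  # constrained value lies in (exp(-2), exp(2))
print(t_init, sigma_init)
```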

NUTS with BlackJAX

This follows the same structure as the DJ paper’s JAX code, but runs a single chain with fewer draws (500 warmup, 500 sampling) so the note renders quickly.

def random_markov_chain(key, kernel, init_state, num_draws):
    @jax.jit
    def one_step(state, key):
        state, _ = kernel(key, state)
        return state, state

    keys = jrd.split(key, num_draws)
    _, states = jax.lax.scan(one_step, init_state, keys)
    return states


def nuts_sample(key, log_density, init_position, num_draws=500, warmup_steps=500):
    warmup_key, sample_key = jrd.split(key, 2)
    warmup = blackjax.window_adaptation(blackjax.nuts, log_density)
    (state, tuned_params), _ = warmup.run(
        warmup_key,
        init_position,
        num_steps=warmup_steps,
    )
    kernel = blackjax.nuts(log_density, **tuned_params).step
    states = random_markov_chain(sample_key, kernel, state, num_draws)
    return states.position


t_draws = nuts_sample(
    nuts_key,
    log_posterior_transformed,
    t_params_init,
    num_draws=500,
    warmup_steps=500,
)

draws = jax.vmap(lambda tp: inv_transform(tp)[0])(t_draws)
posterior_means = jax.tree.map(functools.partial(jnp.mean, axis=0), draws)
posterior_stds = jax.tree.map(functools.partial(jnp.std, axis=0), draws)

print(f"{posterior_means=}")
print(f"{posterior_stds=}")
posterior_means={'alpha': Array(-9.17410714, dtype=float64), 'beta': Array([-4.80498087,  1.14562046], dtype=float64), 'sigma': Array(0.54329635, dtype=float64)}
posterior_stds={'alpha': Array(0.05667742, dtype=float64), 'beta': Array([0.04934707, 0.03229715], dtype=float64), 'sigma': Array(0.03457886, dtype=float64)}
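This BlackJAX run is a single chain, so there is nothing to compute the R_hat column that CmdStanPy reported. Running several chains from different keys and initial positions and then computing split-R-hat closes that gap. A NumPy sketch of the classic (non-rank-normalized) split-R-hat for one scalar quantity; the function name split_rhat is ours, and ArviZ’s az.rhat defaults to a rank-normalized refinement, so its numbers can differ slightly.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for one scalar quantity.

    draws: array of shape (num_chains, num_draws).
    """
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain in two so within-chain drift also inflates R-hat.
    splits = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = splits.shape[1]
    chain_means = splits.mean(axis=1)
    within = splits.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)  # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(0)
agree = rng.normal(size=(4, 1000))  # four chains exploring the same target
stuck = agree + np.array([0.0, 0.0, 0.0, 5.0])[:, None]  # one chain somewhere else
print(split_rhat(agree), split_rhat(stuck))
```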

Posterior Predictive Draws

def posterior_predictive(key, params):
    mu = params["alpha"] + x_new @ params["beta"]
    z = jax.random.normal(key, shape=mu.shape)
    y_new = mu + params["sigma"] * z
    return {"y_new": y_new}


def posterior_predictive_draws(key, draws):
    n = draws["alpha"].shape[0]
    keys = jax.random.split(key, n)
    return jax.vmap(posterior_predictive, in_axes=(0, 0))(keys, draws)


pred_draws = posterior_predictive_draws(pred_key, draws)
posterior_pred_means = jax.tree.map(functools.partial(jnp.mean, axis=0), pred_draws)
posterior_pred_stds = jax.tree.map(functools.partial(jnp.std, axis=0), pred_draws)

print(f"{posterior_pred_means=}")
print(f"{posterior_pred_stds=}")
posterior_pred_means={'y_new': Array([-10.56050834, -16.2718394 , -17.80268581,  -8.55189096], dtype=float64)}
posterior_pred_stds={'y_new': Array([0.54226297, 0.56244718, 0.56034512, 0.52502798], dtype=float64)}
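The predictive standard deviations (about 0.52 to 0.56) sit close to the posterior mean of \(\sigma\) (about 0.54) because, by the law of total variance, \(\operatorname{Var}(y_{\text{new}}) = \mathbb{E}[\sigma^2] + \operatorname{Var}(\mu)\), and \(\mu = \alpha + x_{\text{new}}^\top\beta\) is pinned down far more tightly than \(\sigma\). A NumPy sketch with synthetic draws; the independent normal approximations built from the printed posterior means and standard deviations, and the covariate row x_row, are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(2)
num_draws = 200_000

# Synthetic posterior draws: independent normals from the printed means/stds (an assumption).
alpha = rng.normal(-9.174, 0.057, num_draws)
beta = rng.normal([-4.805, 1.146], [0.049, 0.032], size=(num_draws, 2))
sigma = np.abs(rng.normal(0.543, 0.035, num_draws))

x_row = np.array([0.3, 1.2])  # hypothetical new covariate row
mu = alpha + beta @ x_row  # one predictive mean per draw
y_new = mu + sigma * rng.normal(size=num_draws)

total = y_new.var()
decomposed = np.mean(sigma**2) + mu.var()  # E[sigma^2] + Var(mu)
print(np.sqrt(total), np.sqrt(decomposed))
```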

Comparison: Sections 4/5 vs 6

Section 4 is short because Stan has already solved the hard interface problems:

  • constrained parameters,
  • initialization,
  • Jacobian adjustments,
  • log-density construction,
  • adaptation,
  • generated quantities,
  • posterior summaries.

Section 5 shows that this convenience comes from a generated C++ class that exposes an unconstrained autodiff log density plus generated-quantities forward simulation.

Section 6 shows the JAX equivalent. It is more verbose because transforms and Jacobian terms are manual, but it is transparent tensor code. For a custom MTGP, this is attractive: we can build exactly the factorized covariance and masking logic we want, then sample with BlackJAX/NumPyro.

For the MTGP project, the practical conclusion is:

  • use Stan/CmdStanPy as a trusted reference path for smaller models,
  • use direct JAX or NumPyro for the scalable full-Bayes path,
  • keep MAP/VI as the fast default.