This note replicates the central code fragments from Bob Carpenter’s DJ paper, especially the comparison between:
Section 4: sampling a Stan linear regression through CmdStanPy,
Section 5: what Stan’s generated C++ model object is doing conceptually,
Section 6: writing the same model directly in JAX and sampling with BlackJAX.
The original page is archived locally at ref/dj-paper/dj-paper.html and ref/dj-paper/dj-paper.md.
Takeaway for mtgpax
For mtgpax, the direct JAX path is attractive for full Bayes because it gives us:
differentiable tensor code,
PyTree parameter containers rather than serialized parameter vectors,
explicit transforms and Jacobian adjustments,
BlackJAX/NumPyro-style HMC/NUTS on JAX arrays,
an easier path to GPU/TPU acceleration than Stan.
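Here is a small example of PyTree parameters compared with a serialized vector. It is a minimal sketch with illustrative parameter values. The parameters live in a dict, `jax.flatten_util.ravel_pytree` produces the flat vector a Stan-style interface would expose (plus a function to undo the flattening), and `jax.grad` returns gradients with the same dict structure. The `log_prior` function is a hypothetical example, not part of the original note.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Parameters as a PyTree; sigma is kept on the log scale so every leaf is unconstrained.
params = {
    "alpha": jnp.array(0.5),
    "beta": jnp.array([1.0, -2.0]),
    "log_sigma": jnp.array(0.0),
}

# Flatten to one vector (what a Stan-style interface works with) and back again.
flat, unravel = ravel_pytree(params)
restored = unravel(flat)


# Hypothetical log prior (up to a constant) written directly against the dict.
def log_prior(p):
    return -0.5 * (p["alpha"] / 5.0) ** 2 - 0.5 * jnp.sum((p["beta"] / 2.5) ** 2)


# Gradients come back as a dict with the same keys and shapes as `params`.
grads = jax.grad(log_prior)(params)
print(flat.shape, grads)
```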
Stan still wins on mature diagnostics, robust default NUTS behavior, and not requiring the user to hand-write transforms. The DJ paper's point is that the gap is not philosophical. Stan compiles a constrained model into an unconstrained, differentiable log density, and we can write the same object directly in JAX if we are willing to handle the transforms and their Jacobian adjustments ourselves.
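The regression below has one constrained parameter, `sigma`, which is declared `real<lower=0>`. Stan handles this lower bound with a log transform, so the sampler works with $u = \log\sigma$. The change-of-variables formula then adds the log absolute derivative of the inverse transform to the target:

$$
\sigma = e^{u}, \qquad \left|\frac{d\sigma}{du}\right| = e^{u}, \qquad
\log p_U(u) = \log p_\Sigma\!\left(e^{u}\right) + u .
$$

The trailing $+\,u$ is the Jacobian adjustment that a hand-written JAX log density has to include explicitly.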
For mtgpax, this suggests a practical split:
default fast path: MAP / VI in JAX
full Bayes path: JAX + BlackJAX or NumPyro
reference path: Stan/CmdStanPy for smaller validation models
Linear Regression Model
The example model is a Bayesian linear regression:
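The priors and likelihood can be read off the Stan program below. Stan parameterizes `normal` by its standard deviation and `exponential` by its rate, so the model is:

$$
\begin{aligned}
\alpha &\sim \operatorname{Normal}(0, 5), \\
\beta_p &\sim \operatorname{Normal}(0, 2.5), \qquad p = 1, \dots, P, \\
\sigma &\sim \operatorname{Exponential}(0.5), \\
y_n &\sim \operatorname{Normal}(\alpha + x_n \beta, \sigma), \qquad n = 1, \dots, N,
\end{aligned}
$$

Here $x_n$ is the $n$-th row of $x$. The generated quantities block draws posterior predictive values $y^{\text{new}}$ from the same normal likelihood at the rows of $x^{\text{new}}$.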
This render checks whether CmdStan is available, compiles the Stan program, runs NUTS through CmdStanPy, and prints a compact posterior summary. To keep the page render fast, it uses 500 warmup and 500 sampling iterations per chain, which is half of CmdStan's defaults of 1,000 each. The compilation and sampling are still real CmdStanPy runs.
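```python
import json
from pathlib import Path

import cmdstanpy as csp
import numpy as np

# `data` is defined in an earlier cell that is not shown here. This block uses a
# hypothetical simulated stand-in with P = 2 predictors, which matches the
# beta[1] and beta[2] summary rows below. Plain lists keep json.dumps happy.
rng = np.random.default_rng(1234)
N, N_new, P = 100, 10, 2
x = rng.normal(size=(N, P))
y = 1.0 + x @ np.array([0.5, -1.0]) + rng.normal(scale=0.7, size=N)
x_new = rng.normal(size=(N_new, P))
data = {"N": N, "N_new": N_new, "P": P,
        "x": x.tolist(), "y": y.tolist(), "x_new": x_new.tolist()}

stan_code = r"""
data {
  int<lower=0> N, N_new, P;
  matrix[N, P] x;
  vector[N] y;
  matrix[N_new, P] x_new;
}
parameters {
  real alpha;
  vector[P] beta;
  real<lower=0> sigma;
}
model {
  alpha ~ normal(0, 5);
  beta ~ normal(0, 2.5);
  sigma ~ exponential(0.5);
  y ~ normal(alpha + x * beta, sigma);
}
generated quantities {
  vector[N_new] y_new = to_vector(normal_rng(alpha + x_new * beta, sigma));
}
"""

# Write the Stan program and the data to a working directory.
work = Path("dj_replication_files")
work.mkdir(exist_ok=True)
(work / "linear-regression.stan").write_text(stan_code)
(work / "linear-regression-data.json").write_text(json.dumps(data))

# Fail early with a clear message if CmdStan itself is missing.
try:
    cmdstan_path = csp.cmdstan_path()
    print(f"CmdStan available: {cmdstan_path}")
except ValueError as e:
    raise RuntimeError("CmdStan is not installed; run cmdstanpy.install_cmdstan().") from e

# Compile, then run 4 chains of NUTS with shortened warmup and sampling.
m = csp.CmdStanModel(stan_file=str(work / "linear-regression.stan"))
fit = m.sample(
    data=data,
    chains=4,
    parallel_chains=4,
    iter_warmup=500,
    iter_sampling=500,
    show_progress=False,
    seed=441_582,
)

summary_csp = fit.summary(sig_figs=3)
print(summary_csp.loc[["alpha", "beta[1]", "beta[2]", "sigma"]])
stan_draws_xr = fit.draws_xr()  # requires xarray
```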
Section 4 is short because Stan has already solved the hard interface problems:
constrained parameters,
initialization,
Jacobian adjustments,
log-density construction,
adaptation,
generated quantities,
posterior summaries.
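Initialization, from the list above, shows how much Stan handles without the user seeing it. By default, Stan draws every unconstrained coordinate uniformly from (-2, 2) and maps it back through the inverse transforms. As a result, `sigma` starts somewhere in (e^-2, e^2). The sketch below reproduces that default in JAX. It assumes a hypothetical unconstrained layout `[alpha, beta_1, ..., beta_P, log_sigma]` chosen only for illustration.

```python
import jax
import jax.numpy as jnp

P = 2  # number of regression coefficients, as in the example model
key = jax.random.PRNGKey(1)

# Stan-style default init: uniform(-2, 2) on every unconstrained coordinate.
# Hypothetical layout: [alpha, beta_1, ..., beta_P, log_sigma].
theta0 = jax.random.uniform(key, (P + 2,), minval=-2.0, maxval=2.0)

# Map back to the constrained scale; only sigma needs a transform.
init = {
    "alpha": theta0[0],
    "beta": theta0[1 : 1 + P],
    "sigma": jnp.exp(theta0[-1]),
}
print(init)
```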
Section 5 shows that this convenience comes from a generated C++ class. That class exposes the log density as a function of unconstrained parameters, with gradients from Stan's reverse-mode autodiff, plus forward simulation for the generated quantities.
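The following minimal Python/JAX sketch makes Section 5 concrete. It shows the kind of object Stan generates, assuming the unconstrained layout `theta = [alpha, beta_1, ..., beta_P, log_sigma]`. The class and method names are hypothetical. They only loosely mirror the C++ model's `log_prob` (log density on the unconstrained scale) and `write_array` (constrain the parameters and run generated quantities). Unlike Stan during sampling, this sketch keeps the normalizing constants. That choice does not change the posterior.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class LinearRegressionModel:
    """Hypothetical Python analogue of Stan's generated model class."""

    def __init__(self, x, y, x_new):
        self.x, self.y, self.x_new = x, y, x_new
        self.P = x.shape[1]

    def constrain(self, theta):
        # Unconstrained vector -> named, constrained parameters (sigma = exp(u) > 0).
        return {"alpha": theta[0], "beta": theta[1 : 1 + self.P], "sigma": jnp.exp(theta[-1])}

    def unconstrain(self, params):
        # Inverse transform: log maps sigma > 0 back to the real line.
        return jnp.concatenate([
            jnp.atleast_1d(params["alpha"]),
            params["beta"],
            jnp.atleast_1d(jnp.log(params["sigma"])),
        ])

    def log_density(self, theta, jacobian=True):
        p = self.constrain(theta)
        lp = norm.logpdf(p["alpha"], 0.0, 5.0)
        lp += jnp.sum(norm.logpdf(p["beta"], 0.0, 2.5))
        lp += jnp.log(0.5) - 0.5 * p["sigma"]  # sigma ~ exponential(rate = 0.5)
        lp += jnp.sum(norm.logpdf(self.y, p["alpha"] + self.x @ p["beta"], p["sigma"]))
        if jacobian:
            lp += theta[-1]  # log |d sigma / d u| = u
        return lp

    def log_density_gradient(self, theta):
        # Value and gradient by autodiff, as the sampler needs them.
        return jax.value_and_grad(self.log_density)(theta)

    def generated_quantities(self, theta, key):
        # Posterior predictive draw y_new ~ normal(alpha + x_new * beta, sigma).
        p = self.constrain(theta)
        mu = p["alpha"] + self.x_new @ p["beta"]
        return mu + p["sigma"] * jax.random.normal(key, mu.shape)


# Usage on small synthetic inputs.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (20, 2))
y = 1.0 + x @ jnp.array([0.5, -1.0])
model = LinearRegressionModel(x, y, x[:3])
theta = model.unconstrain({"alpha": 1.0, "beta": jnp.array([0.5, -1.0]), "sigma": 0.7})
lp, grad = model.log_density_gradient(theta)
y_new = model.generated_quantities(theta, jax.random.PRNGKey(1))
```

A sampler only ever calls `log_density_gradient` on unconstrained vectors. `constrain` and `generated_quantities` are applied afterwards, once per saved draw.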
Section 6 shows the JAX equivalent. It is more verbose because transforms and Jacobian terms are manual, but it is transparent tensor code. For a custom MTGP, this is attractive: we can build exactly the factorized covariance and masking logic we want, then sample with BlackJAX/NumPyro.
For the MTGP project, the practical conclusion is:
use Stan/CmdStanPy as a trusted reference path for smaller models,
use direct JAX or NumPyro for the scalable full-Bayes path,