Source: DifferentSoftware/linear.Rmd
This example is a deliberately small linear regression used to compare fitting interfaces. In Python the analogous ladder is:
- statsmodels formula interface for ordinary least squares,
- matrix least squares with numpy,
- CmdStanPy for an explicit Bayesian Gaussian regression,
- BlackJAX for the same model as a hand-written log density.
Simulate fake data
Code
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
# Seeded generator: every run of this cell reproduces the same fake data.
rng = np.random.default_rng(2141)

# Sample size and the coefficients used to generate y, ordered
# (intercept, x1, x2) to match the columns of X below.
N = 100
beta_true = np.array([1.0, 2.0, 3.0])

# Two predictors of N standard-normal draws each; x1 is drawn before x2,
# which fixes the order in which the random stream is consumed.
x1 = rng.normal(size=N)
x2 = rng.normal(size=N)

# Design matrix: a column of ones for the intercept, then x1 and x2.
X = np.column_stack((np.ones(N), x1, x2))

# Outcome = linear predictor + Gaussian noise with standard deviation sigma.
sigma = 2.0
y = X @ beta_true + rng.normal(loc=0, scale=sigma, size=N)

# Collect outcome and predictors in one frame and preview the first rows.
dat = pd.DataFrame(dict(y=y, x1=x1, x2=x2))
dat.head()
|   | y | x1 | x2 |
|---|---|---|---|
| 0 | 1.211638 | 0.140604 | 0.746172 |
| 1 | 13.013274 | 1.368083 | 2.716414 |
| 2 | -0.784397 | 1.793252 | -1.252496 |
| 3 | 1.432282 | -0.308162 | -0.008264 |
| 4 | -5.842190 | 0.574745 | -1.491098 |
Matrix least squares
Code
# Ordinary least squares through numpy's least-squares solver; only the
# coefficient vector (first element of the returned tuple) is used.
beta_hat = np.linalg.lstsq(X, y, rcond=None)[0]

# Residual standard deviation, dividing the residual sum of squares by
# N minus the number of columns of X.
resid = y - X @ beta_hat
sigma_hat = np.sqrt(resid.dot(resid) / (N - X.shape[1]))

# Coefficient covariance sigma_hat^2 * (X'X)^{-1}; the standard errors are
# the square roots of its diagonal.
vcov = np.linalg.inv(X.T @ X) * sigma_hat**2
se = np.sqrt(np.diag(vcov))

# Summary table: one row per coefficient.
pd.DataFrame(
    {"estimate": beta_hat, "std_error": se},
    index=["Intercept", "x1", "x2"],
)
|   | estimate | std_error |
|---|---|---|
| Intercept | 1.282021 | 0.199704 |
| x1 | 1.736972 | 0.216956 |
| x2 | 2.968017 | 0.193003 |
CmdStanPy version
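This is the explicit counterpart of rstanarm::stan_glm(y ~ x1 + x2): the same Gaussian linear model, with the priors written out by hand rather than taken from rstanarm's defaults.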
// Gaussian linear regression: y is normal around X * beta with a common
// standard deviation sigma.
data {
int<lower=1> N; // number of observations (rows of X, length of y)
int<lower=1> K; // number of columns of X, i.e. number of coefficients
matrix[N, K] X; // design matrix
vector[N] y; // outcome
}
parameters {
vector[K] beta; // regression coefficients, one per column of X
real<lower=0> sigma; // residual standard deviation, constrained to be non-negative
}
model {
beta ~ normal(0, 10); // normal prior (mean 0, sd 10) on every coefficient
sigma ~ exponential(1); // exponential prior with rate 1 on the residual sd
y ~ normal(X * beta, sigma); // likelihood, vectorised over the N observations
}
Code
from pathlib import Path
from cmdstanpy import CmdStanModel
# Stan program for the Gaussian linear regression, held as a raw string so it
# is written to disk exactly as typed.
stan_code = r'''
data { int<lower=1> N; int<lower=1> K; matrix[N, K] X; vector[N] y; }
parameters { vector[K] beta; real<lower=0> sigma; }
model { beta ~ normal(0, 10); sigma ~ exponential(1); y ~ normal(X * beta, sigma); }
'''

# The model is handed to CmdStanModel as a file path, so persist the program
# first. parents=True keeps the directory creation working if the target path
# is ever nested more deeply; the explicit encoding makes the written bytes
# independent of the platform default.
stan_file = Path("models/linear_gaussian.stan")
stan_file.parent.mkdir(parents=True, exist_ok=True)
stan_file.write_text(stan_code, encoding="utf-8")

# Compilation and sampling are left commented out. K is read from the design
# matrix so the data passed to Stan cannot drift from X's column count.
# model = CmdStanModel(stan_file=stan_file)
# fit = model.sample(data={"N": N, "K": X.shape[1], "X": X, "y": y}, seed=2141)
BlackJAX log-density sketch
BlackJAX is useful when the point is sampler mechanics rather than formula convenience.
Code
# Sketch only; kept commented out.
# Unnormalised log posterior of the Gaussian regression, written on the
# unconstrained parameterisation pos = {"beta": ..., "log_sigma": ...} and
# using the same priors as the Stan model: beta ~ normal(0, 10) and
# sigma ~ exponential(1).
#
# import jax.numpy as jnp
# import jax.scipy.stats as stats
#
# def logdensity(pos):
#     beta = pos["beta"]
#     sigma = jnp.exp(pos["log_sigma"])
#     # Prior stated on sigma itself (the constrained scale) ...
#     lp = stats.expon.logpdf(sigma)
#     # ... so the change of variables sigma = exp(log_sigma) contributes
#     # log|d sigma / d log_sigma| = log_sigma. (A prior placed directly on
#     # log_sigma would need no Jacobian term.)
#     lp += pos["log_sigma"]
#     lp += jnp.sum(stats.norm.logpdf(beta, 0, 10))
#     lp += jnp.sum(stats.norm.logpdf(y, X @ beta, sigma))
#     return lp