Different software options: linear regression

Source: DifferentSoftware/linear.Rmd

This example is a deliberately small linear regression used to compare fitting interfaces. In Python the analogous ladder is:

  1. statsmodels formula interface for ordinary least squares,
  2. matrix least squares with numpy,
  3. CmdStanPy for an explicit Bayesian Gaussian regression,
  4. BlackJAX for the same model as a hand-written log density.

Simulate fake data

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(2141)
N = 100
beta_true = np.array([1.0, 2.0, 3.0])
x1 = rng.normal(size=N)
x2 = rng.normal(size=N)
X = np.column_stack([np.ones(N), x1, x2])
sigma = 2.0
y = X @ beta_true + rng.normal(0, sigma, size=N)
dat = pd.DataFrame({"y": y, "x1": x1, "x2": x2})
dat.head()
           y         x1         x2
0   1.211638   0.140604   0.746172
1  13.013274   1.368083   2.716414
2  -0.784397   1.793252  -1.252496
3   1.432282  -0.308162  -0.008264
4  -5.842190   0.574745  -1.491098

Formula OLS: lm(y ~ x1 + x2)

fit1 = smf.ols("y ~ x1 + x2", data=dat).fit()
fit1.params
Intercept    1.282021
x1           1.736972
x2           2.968017
dtype: float64
pd.DataFrame({"estimate": fit1.params, "std_error": fit1.bse})
           estimate  std_error
Intercept  1.282021   0.199704
x1         1.736972   0.216956
x2         2.968017   0.193003
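statsmodels keeps the design matrix it built from the formula in `fit.model.exog`, so one can confirm that `y ~ x1 + x2` expands to the same intercept-plus-two-columns matrix `X` used in the matrix version, and read off 95% intervals with `conf_int()`. A small self-contained sketch that re-simulates the data with the same seed and steps:

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Re-simulate the example data (same seed and order of draws as above).
rng = np.random.default_rng(2141)
N = 100
beta_true = np.array([1.0, 2.0, 3.0])
x1 = rng.normal(size=N)
x2 = rng.normal(size=N)
X = np.column_stack([np.ones(N), x1, x2])
y = X @ beta_true + rng.normal(0, 2.0, size=N)
dat = pd.DataFrame({"y": y, "x1": x1, "x2": x2})

fit1 = smf.ols("y ~ x1 + x2", data=dat).fit()

# The formula expands to an intercept column plus x1 and x2, i.e. the same X.
same_design = np.allclose(fit1.model.exog, X)

# 95% confidence intervals; statsmodels labels the bound columns 0 and 1.
ci = fit1.conf_int(alpha=0.05)
ci.columns = ["lower", "upper"]

# fit1.scale is the residual variance SSR / (N - K); its square root is sigma_hat.
sigma_hat = np.sqrt(fit1.scale)
print(same_design)
print(ci)
print(sigma_hat)
```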

Matrix least squares

# Least-squares coefficients, residual standard deviation, and classical standard errors.
beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
resid = y - X @ beta_hat
sigma_hat = np.sqrt((resid @ resid) / (N - X.shape[1]))  # N - K degrees of freedom
vcov = sigma_hat**2 * np.linalg.inv(X.T @ X)              # sigma_hat^2 (X^T X)^{-1}
se = np.sqrt(np.diag(vcov))
pd.DataFrame({"estimate": beta_hat, "std_error": se}, index=["Intercept", "x1", "x2"])
           estimate  std_error
Intercept  1.282021   0.199704
x1         1.736972   0.216956
x2         2.968017   0.193003
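The code above computes $\hat\beta$, $\hat\sigma^2 = \lVert y - X\hat\beta\rVert^2/(N-K)$ and $\widehat{\mathrm{Var}}(\hat\beta) = \hat\sigma^2 (X^\top X)^{-1}$. Forming $X^\top X$ explicitly squares the condition number of $X$, so a common alternative (not what this notebook uses) is the QR factorization $X = QR$: then $R\hat\beta = Q^\top y$ and $(X^\top X)^{-1} = R^{-1}R^{-\top}$. A sketch in NumPy, checked against `np.linalg.lstsq`:

```python
import numpy as np

# Re-simulate the example data (same seed and order of draws as above).
rng = np.random.default_rng(2141)
N = 100
x1 = rng.normal(size=N)
x2 = rng.normal(size=N)
X = np.column_stack([np.ones(N), x1, x2])
y = X @ np.array([1.0, 2.0, 3.0]) + rng.normal(0, 2.0, size=N)
K = X.shape[1]

# Reduced QR: Q is N x K with orthonormal columns, R is K x K upper triangular.
Q, R = np.linalg.qr(X)
beta_qr = np.linalg.solve(R, Q.T @ y)  # solves R beta = Q^T y

resid = y - X @ beta_qr
sigma2_hat = (resid @ resid) / (N - K)

# (X^T X)^{-1} = R^{-1} R^{-T}, computed from the small K x K factor only.
R_inv = np.linalg.inv(R)
vcov_qr = sigma2_hat * (R_inv @ R_inv.T)
se_qr = np.sqrt(np.diag(vcov_qr))

beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.allclose(beta_qr, beta_lstsq), se_qr)
```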

CmdStanPy version

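This is an explicit Stan program for the model that rstanarm::stan_glm(y ~ x1 + x2) fits. The priors here are written by hand and are not rstanarm's defaults (which are rescaled using the data), so the posterior should be close to, but not identical with, a default stan_glm fit.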

data {
  int<lower=1> N;
  int<lower=1> K;
  matrix[N, K] X;
  vector[N] y;
}
parameters {
  vector[K] beta;
  real<lower=0> sigma;
}
model {
  beta ~ normal(0, 10);
  sigma ~ exponential(1);
  y ~ normal(X * beta, sigma);
}
from pathlib import Path
from cmdstanpy import CmdStanModel

stan_code = r'''
data { int<lower=1> N; int<lower=1> K; matrix[N, K] X; vector[N] y; }
parameters { vector[K] beta; real<lower=0> sigma; }
model { beta ~ normal(0, 10); sigma ~ exponential(1); y ~ normal(X * beta, sigma); }
'''
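stan_file = Path("models/linear_gaussian.stan")
stan_file.parent.mkdir(parents=True, exist_ok=True)
stan_file.write_text(stan_code)  # returns the number of characters written
# Compiling and sampling require a CmdStan installation (see cmdstanpy.install_cmdstan()).
# model = CmdStanModel(stan_file=stan_file)
# fit = model.sample(data={"N": N, "K": X.shape[1], "X": X, "y": y}, seed=2141)
# fit.summary()
208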
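Without running CmdStan, one can see why the posterior should sit close to the OLS fit. If $\sigma$ were held fixed, the normal(0, 10) prior on `beta` is conjugate and the conditional posterior is Gaussian with covariance $V = (X^\top X/\sigma^2 + I/10^2)^{-1}$ and mean $m = V X^\top y/\sigma^2$. The sketch below plugs in the OLS residual standard deviation for $\sigma$; this is an approximation for intuition only, not what Stan computes, since Stan also integrates over $\sigma$ under its exponential(1) prior.

```python
import numpy as np

# Re-simulate the example data (same seed and order of draws as above).
rng = np.random.default_rng(2141)
N = 100
x1 = rng.normal(size=N)
x2 = rng.normal(size=N)
X = np.column_stack([np.ones(N), x1, x2])
y = X @ np.array([1.0, 2.0, 3.0]) + rng.normal(0, 2.0, size=N)
K = X.shape[1]

# OLS fit and residual standard deviation, used as a plug-in value for sigma.
beta_ols, *_ = np.linalg.lstsq(X, y, rcond=None)
resid = y - X @ beta_ols
sigma_plug = np.sqrt(resid @ resid / (N - K))
se_ols = np.sqrt(np.diag(sigma_plug**2 * np.linalg.inv(X.T @ X)))

# Conjugate update for beta | sigma with prior beta ~ normal(0, tau^2 I), tau = 10.
tau = 10.0
precision = X.T @ X / sigma_plug**2 + np.eye(K) / tau**2
V = np.linalg.inv(precision)
m = V @ (X.T @ y) / sigma_plug**2

print(np.round(m - beta_ols, 4))  # difference from OLS is tiny with tau = 10
print(np.sqrt(np.diag(V)), se_ols)
```

With only 3 coefficients and 100 observations, the prior precision $1/10^2$ is negligible next to $X^\top X/\sigma^2$, which is why a weakly informative Bayesian fit and OLS agree closely here.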

BlackJAX log-density sketch

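BlackJAX is useful when the point is sampler mechanics rather than formula convenience. It supplies samplers but no modeling language, so the model has to be written by hand as a log-density function of unconstrained parameters; here sigma is sampled on the log scale as log_sigma.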
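Concretely, with $\eta = \log\sigma$ as the unconstrained parameter, the target for the Stan model above is the log posterior up to a constant, including the log-Jacobian $\log\lvert d\sigma/d\eta\rvert = \eta$ of $\sigma = e^{\eta}$ (normal(mu, sd) as in Stan):

```latex
\log p(\beta, \eta \mid y)
  = \sum_{k=1}^{K} \log \mathrm{normal}(\beta_k \mid 0, 10)
  \;-\; e^{\eta}
  \;+\; \eta
  \;+\; \sum_{n=1}^{N} \log \mathrm{normal}\!\left(y_n \mid x_n^\top \beta,\; e^{\eta}\right)
  \;+\; \text{const}
```

Here $-e^{\eta} = -\sigma$ is the exponential(1) log density of $\sigma$. The $+\eta$ term is what keeps the implied prior on $\sigma$ equal to exponential(1); dropping it, or placing a prior directly on $\eta$ and still adding it, changes the model.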

import jax.numpy as jnp
import jax.scipy.stats as stats

# Position on the unconstrained scale: {"beta": length-K vector, "log_sigma": scalar}.
def logdensity(pos):
    beta = pos["beta"]
    log_sigma = pos["log_sigma"]
    sigma = jnp.exp(log_sigma)
    lp = jnp.sum(stats.norm.logpdf(beta, 0.0, 10.0))     # beta ~ normal(0, 10)
    lp += -sigma                                          # sigma ~ exponential(1): log density is -sigma
    lp += log_sigma                                       # log-Jacobian of sigma = exp(log_sigma)
    lp += jnp.sum(stats.norm.logpdf(y, X @ beta, sigma))  # likelihood, X and y from above
    return lp
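To show what a log-density-only sampler consumes, here is a deliberately simple stand-in: a Gaussian random-walk Metropolis sampler in plain NumPy on the target implied by the Stan program above (exponential(1) prior on sigma, sampled as log_sigma). This is not BlackJAX and not NUTS; the step size, chain length, warmup and starting point are arbitrary choices for the sketch. BlackJAX's samplers likewise need only a log-density function of the position, but use gradients and adaptation to explore far more efficiently.

```python
import numpy as np

# Re-simulate the example data (same seed and order of draws as above).
rng = np.random.default_rng(2141)
N = 100
x1 = rng.normal(size=N)
x2 = rng.normal(size=N)
X = np.column_stack([np.ones(N), x1, x2])
y = X @ np.array([1.0, 2.0, 3.0]) + rng.normal(0, 2.0, size=N)
K = X.shape[1]

def normal_logpdf(x, mu, sd):
    """Elementwise log density of normal(mu, sd), with sd the standard deviation."""
    return -0.5 * ((x - mu) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)

def logdensity(theta):
    """Log posterior (up to a constant) at theta = (beta_1, ..., beta_K, log_sigma)."""
    beta, log_sigma = theta[:K], theta[K]
    sigma = np.exp(log_sigma)
    lp = normal_logpdf(beta, 0.0, 10.0).sum()        # beta ~ normal(0, 10)
    lp += -sigma + log_sigma                          # exponential(1) on sigma + log-Jacobian
    lp += normal_logpdf(y, X @ beta, sigma).sum()     # Gaussian likelihood
    return lp

def random_walk_metropolis(logdensity, theta0, n_iter, step, rng):
    """Gaussian random-walk Metropolis; returns (draws, acceptance rate)."""
    theta = np.asarray(theta0, dtype=float)
    lp = logdensity(theta)
    draws = np.empty((n_iter, theta.size))
    n_accept = 0
    for i in range(n_iter):
        proposal = theta + step * rng.normal(size=theta.size)
        lp_prop = logdensity(proposal)
        # Accept with probability min(1, exp(lp_prop - lp)).
        if np.log(rng.uniform()) < lp_prop - lp:
            theta, lp = proposal, lp_prop
            n_accept += 1
        draws[i] = theta
    return draws, n_accept / n_iter

# Start at the OLS estimates with log_sigma = log(2); drop the first 2000 draws as warmup.
beta_ols, *_ = np.linalg.lstsq(X, y, rcond=None)
theta0 = np.append(beta_ols, np.log(2.0))
draws, accept_rate = random_walk_metropolis(logdensity, theta0, 20000, 0.08, rng)
posterior_mean = draws[2000:].mean(axis=0)
print(np.round(posterior_mean[:K], 2), round(float(np.exp(posterior_mean[K])), 2), accept_rate)
```

The only model-specific piece is `logdensity`; swapping in a different model means rewriting that function, which is the trade-off this section contrasts with the formula interfaces.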