Logistic regression priors

Source: LogisticPriors/logistic_priors.Rmd

This example shows why priors matter in logistic regression, especially with small samples. The source R function simulates binary data from a latent logistic variable, fits both glm and stan_glm, and repeats the exercise for n = 10, 100, 1000.

Setup

Code
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy.special import expit

Simulating the example

The R code generates

$$x_i \sim \operatorname{Uniform}(-1, 1), \qquad z_i \sim \operatorname{Logistic}(a + b\,x_i,\ 1), \qquad y_i = \mathbb{1}[z_i > 0].$$

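Because the logistic distribution is symmetric about its location, $\Pr(z_i > 0) = \operatorname{logit}^{-1}(a + b\,x_i)$, so we can skip the latent variable and simulate $y_i$ directly as a Bernoulli draw with that probability.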

Code
def simulate_logistic_data(n, a=-2.0, b=0.8, seed=None):
    """Draw ``n`` observations from a logistic regression with intercept ``a`` and slope ``b``.

    The predictor is uniform on [-1, 1); each outcome is a Bernoulli draw with
    success probability ``expit(a + b * x)``. Passing the same ``seed`` gives the
    same dataset.

    Returns a DataFrame with columns ``x`` (predictor), ``y`` (0/1 outcome) and
    ``p_true`` (the success probability used to draw ``y``).
    """
    generator = np.random.default_rng(seed)
    predictor = generator.uniform(-1, 1, size=n)
    success_prob = expit(a + b * predictor)
    outcome = generator.binomial(1, success_prob)
    frame = {"x": predictor, "y": outcome, "p_true": success_prob}
    return pd.DataFrame(frame)

MLE and a normal-prior approximation

The R page compares maximum likelihood to stan_glm(..., prior = normal(0.5, 0.5, autoscale=FALSE)). In Python, a simple way to see the same shrinkage is to compute the posterior mode under independent Normal priors: normal(0.5, 0.5) on the slope, matching the R call, and a weakly informative normal(0, 2.5) on the intercept. The second function below then uses a Gaussian approximation around that mode to obtain standard errors.

Code
def fit_logistic_mle(data):
    """Fit ``y ~ 1 + x`` by maximum likelihood with statsmodels' ``Logit``.

    Returns the fitted statsmodels results object; the intercept term is named
    ``"const"`` by ``sm.add_constant``. Optimizer output is suppressed.
    """
    design = sm.add_constant(data[["x"]])
    logit_model = sm.Logit(data["y"], design)
    return logit_model.fit(disp=False)


def fit_logistic_normal_prior(
    data,
    prior_mean=0.5,
    prior_sd=0.5,
    intercept_prior_mean=0.0,
    intercept_prior_sd=2.5,
):
    """Posterior mode and Gaussian-approximation standard errors for ``y ~ 1 + x``.

    Uses independent Normal priors
    ``intercept ~ N(intercept_prior_mean, intercept_prior_sd)`` and
    ``slope ~ N(prior_mean, prior_sd)``; all ``*_sd`` arguments are standard
    deviations.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a numeric predictor column ``"x"`` and a 0/1 outcome
        column ``"y"``.
    prior_mean, prior_sd : float
        Mean and standard deviation of the Normal prior on the slope.
    intercept_prior_mean, intercept_prior_sd : float
        Mean and standard deviation of the Normal prior on the intercept.
        The defaults (0, 2.5) are the values previously hard-coded here.

    Returns
    -------
    (pandas.Series, pandas.Series)
        The posterior mode and the approximate posterior standard deviations
        (square roots of the diagonal of the inverse negative Hessian at the
        mode), both indexed by ``["const", "x"]``.

    Warns
    -----
    RuntimeWarning
        If ``scipy.optimize.minimize`` reports that it did not converge.
    """
    import warnings

    from scipy.optimize import minimize

    x_col = data["x"].to_numpy(dtype=float)
    y = data["y"].to_numpy(dtype=float)
    # Build the design matrix explicitly: sm.add_constant skips the intercept
    # column when "x" is constant (e.g. a single row), which would leave X
    # with one column and break the two-parameter model below.
    X = np.column_stack([np.ones_like(x_col), x_col])
    prior_mean_vec = np.array([intercept_prior_mean, prior_mean], dtype=float)
    prior_prec = np.diag([1 / intercept_prior_sd**2, 1 / prior_sd**2])

    def objective(beta):
        # Negative log posterior (up to a constant).
        eta = X @ beta
        # y*log(p) + (1-y)*log(1-p) == y*eta - log(1 + e^eta), computed stably.
        loglik = np.sum(y * eta - np.logaddexp(0, eta))
        diff = beta - prior_mean_vec
        logprior = -0.5 * diff @ prior_prec @ diff
        return -(loglik + logprior)

    def gradient(beta):
        p = expit(X @ beta)
        diff = beta - prior_mean_vec
        return -(X.T @ (y - p) - prior_prec @ diff)

    opt = minimize(objective, x0=np.zeros(2), jac=gradient, method="BFGS")
    if not opt.success:
        warnings.warn(
            f"posterior-mode optimization did not converge: {opt.message}",
            RuntimeWarning,
            stacklevel=2,
        )
    beta = opt.x
    p = expit(X @ beta)
    W = p * (1 - p)
    # Negative Hessian of the log posterior at the mode:
    # Fisher information of the logistic likelihood plus the prior precision.
    neg_hessian = X.T @ (W[:, None] * X) + prior_prec
    cov = np.linalg.inv(neg_hessian)
    index = ["const", "x"]
    return pd.Series(beta, index=index), pd.Series(np.sqrt(np.diag(cov)), index=index)
Code
def bayes_sim(n, a=-2.0, b=0.8, seed=363852):
    """Simulate one dataset and compare the MLE with the Normal-prior posterior mode.

    Returns a DataFrame indexed by term (intercept and slope) with four columns:
    the MLE estimate and its standard error, and the posterior-mode estimate
    and its Gaussian-approximation standard error.
    """
    sim = simulate_logistic_data(n, a=a, b=b, seed=seed)
    mle_fit = fit_logistic_mle(sim)
    prior_mode, prior_se = fit_logistic_normal_prior(sim)
    columns = {}
    columns["estimate_mle"] = mle_fit.params
    columns["se_mle"] = mle_fit.bse
    columns["estimate_normal_prior_mode"] = prior_mode
    columns["se_normal_prior_approx"] = prior_se
    return pd.DataFrame(columns)

# One simulation per sample size; offsetting the seed by n gives each size its own dataset.
results = {
    size: bayes_sim(size, seed=363852 + size)
    for size in (10, 100, 1000)
}
results[10]
estimate_mle se_mle estimate_normal_prior_mode se_normal_prior_approx
const -2.288005 1.160043 -1.902734 0.885274
x 0.978329 1.747827 0.531777 0.472606
Code
# Stack the per-sample-size tables into one frame with an (n, term) row index.
pd.concat(results, names=["n", "term"])
estimate_mle se_mle estimate_normal_prior_mode se_normal_prior_approx
n term
10 const -2.288005 1.160043 -1.902734 0.885274
x 0.978329 1.747827 0.531777 0.472606
100 const -2.621656 0.454423 -2.367570 0.354713
x 1.509546 0.709560 0.853834 0.386537
1000 const -1.997080 0.099781 -1.991432 0.099040
x 0.665692 0.175332 0.646393 0.164994

With n = 10, the likelihood can be weak and the prior can visibly pull the slope toward its prior center. By n = 1000, the likelihood dominates and the two estimates are typically close.

CmdStanPy version

For a closer analogue to rstanarm::stan_glm, sample the two-parameter logistic model directly. The slope prior below matches the source example’s normal(0.5, 0.5); the intercept prior is weakly informative.

Code
from pathlib import Path
from cmdstanpy import CmdStanModel

stan_code = """
data {
  int<lower=1> N;
  vector[N] x;
  array[N] int<lower=0, upper=1> y;
}
parameters {
  real alpha;
  real beta;
}
model {
  alpha ~ normal(0, 2.5);
  beta ~ normal(0.5, 0.5);
  y ~ bernoulli_logit(alpha + beta * x);
}
"""
stan_file = Path("_generated/logistic_prior_demo.stan")
stan_file.parent.mkdir(parents=True, exist_ok=True)
# Only rewrite the model file when its text changed. CmdStanModel recompiles
# when the .stan file is newer than the cached executable, so an
# unconditional write would force a full C++ rebuild on every run.
if not stan_file.exists() or stan_file.read_text(encoding="utf-8") != stan_code:
    stan_file.write_text(stan_code, encoding="utf-8")
model = CmdStanModel(stan_file=str(stan_file))

# Same simulation settings as above, with n = 100.
data = simulate_logistic_data(100, seed=363852)
fit = model.sample(
    data={"N": len(data), "x": data["x"].to_numpy(), "y": data["y"].astype(int).to_list()},
    seed=363852,
    chains=4,
    parallel_chains=4,
    show_progress=False,
)
fit.draws_pd()[["alpha", "beta"]].describe(percentiles=[0.05, 0.5, 0.95])
alpha beta
count 4000.000000 4000.000000
mean -2.499834 0.443049
std 0.369157 0.392989
min -3.925290 -0.771283
5% -3.132582 -0.215418
50% -2.483970 0.436503
95% -1.934624 1.087618
max -1.377130 1.832090

BlackJAX log-density sketch

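The model is small enough that its unnormalized log posterior density is easy to write directly; additive constants are dropped because MCMC does not need them. This is the target one would pass to a custom sampler. The version below uses NumPy. To use it with a JAX-based sampler such as BlackJAX, swap `np` for `jax.numpy` so the function can be traced and differentiated.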

Code
from scipy.special import expit

# Plain NumPy arrays of the simulated dataset used in the CmdStanPy fit.
x_np, y_np = data["x"].to_numpy(), data["y"].to_numpy()

def logistic_prior_logdensity(position, *, x=None, y=None):
    """Unnormalized log posterior density of the two-parameter logistic model.

    Priors are ``alpha ~ Normal(0, 2.5)`` and ``beta ~ Normal(0.5, 0.5)``
    (standard deviations); additive normalizing constants are omitted.

    Parameters
    ----------
    position : mapping
        Provides scalar ``"alpha"`` (intercept) and ``"beta"`` (slope).
    x, y : array-like, optional (keyword-only)
        Predictor values and 0/1 outcomes. Default to the module-level
        ``x_np`` / ``y_np`` arrays, so ``logistic_prior_logdensity(position)``
        behaves as before.

    Returns
    -------
    numpy.float64
        Log likelihood plus log prior, up to an additive constant.
    """
    x = x_np if x is None else np.asarray(x, dtype=float)
    y = y_np if y is None else np.asarray(y, dtype=float)
    alpha = position["alpha"]
    beta = position["beta"]
    eta = alpha + beta * x
    # Bernoulli-logit log likelihood: y*log(p) + (1-y)*log(1-p) == y*eta - log(1 + e^eta).
    # logaddexp stays accurate in the tails; the old log(clip(expit(eta)))
    # saturated at log(1e-12) once |eta| exceeded about 27.6.
    loglik = np.sum(y * eta - np.logaddexp(0, eta))
    logprior_alpha = -0.5 * (alpha / 2.5) ** 2
    logprior_beta = -0.5 * ((beta - 0.5) / 0.5) ** 2
    return loglik + logprior_alpha + logprior_beta

# Smoke test: evaluate the density at alpha = 0 and beta at its prior mean (0.5).
logistic_prior_logdensity({"alpha": 0.0, "beta": 0.5})
np.float64(-70.7820037648752)