Arsenic wells: building logistic regression models

Source: Arsenic/arsenic_logistic_building.Rmd

This ports the core model-building sequence for the Bangladesh wells example from R (rstanarm) to Python.

Load data

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from scipy.special import expit

root = Path("../../ROS-Examples")
wells = pd.read_csv(root / "Arsenic/data/wells.csv")
wells["dist100"] = wells["dist"] / 100
wells.head()
switch arsenic dist dist100 assoc educ educ4
0 1 2.36 16.826000 0.16826 0 0 0.0
1 1 0.71 47.321999 0.47322 0 0 0.0
2 0 2.07 20.966999 0.20967 0 10 2.5
3 1 1.15 21.486000 0.21486 0 12 3.0
4 1 1.10 40.874001 0.40874 1 14 3.5

Null-model log scores

y = wells["switch"].to_numpy()
for p in [0.5, y.mean()]:
    log_score = np.sum(y*np.log(p) + (1-y)*np.log(1-p))
    print(round(p, 3), round(log_score, 1))
0.5 -2093.3
0.575 -2059.0
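
The null-model log score has a closed form. Predicting the same probability p for every household gives n[ȳ log p + (1 − ȳ) log(1 − p)], and this is maximized at p = ȳ. The −2093.3 above is therefore just n log 0.5 with n = 3020 wells. Below is a small sketch that checks both facts. It uses simulated outcomes rather than the wells data; the 0.6 switching rate, the sample size, and the helper name constant_log_score are assumptions for illustration.

```python
import numpy as np

def constant_log_score(y, p):
    """Log score of predicting the same Pr(y = 1) = p for every observation."""
    y = np.asarray(y, dtype=float)
    return np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

# Simulated 0/1 outcomes (assumption: not the wells data)
rng = np.random.default_rng(0)
y_sim = rng.binomial(1, 0.6, size=1000)
ybar = y_sim.mean()

# Closed form at p = ybar: n * [ybar*log(ybar) + (1 - ybar)*log(1 - ybar)]
closed_form = len(y_sim) * (ybar * np.log(ybar) + (1 - ybar) * np.log(1 - ybar))

# Scan a grid of constant predictions; the best grid point should sit next to ybar
grid = np.linspace(0.01, 0.99, 99)
scores = np.array([constant_log_score(y_sim, p) for p in grid])
best_p = grid[np.argmax(scores)]
print(round(ybar, 3), round(best_p, 2), round(closed_form, 1))
```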

Single predictor: distance

R original:

stan_glm(switch ~ dist100, family = binomial(link = "logit"), data=wells)

Python frequentist analog (maximum likelihood via statsmodels; this gives point estimates rather than posterior draws):

fit_2 = smf.logit("switch ~ dist100", data=wells).fit()
fit_2.params
Optimization terminated successfully.
         Current function value: 0.674874
         Iterations 4
Intercept    0.605959
dist100     -0.621882
dtype: float64
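
The coefficients are on the logit scale, so it helps to translate them back to probabilities. The sketch below uses the fit_2 point estimates printed above, copied by hand so no refit is needed, together with the "divide by 4" rule of thumb: the logistic curve is steepest at probability 0.5, where its slope is β/4. The helper inv_logit is written out in NumPy here; it computes the same thing as scipy.special.expit.

```python
import numpy as np

def inv_logit(x):
    return 1 / (1 + np.exp(-x))

# Coefficients copied from the fit_2 output above
b0, b_dist100 = 0.605959, -0.621882

# Predicted Pr(switching) right next to a safe well vs. 100 m away (dist100 = 1)
p_0m = inv_logit(b0 + b_dist100 * 0.0)
p_100m = inv_logit(b0 + b_dist100 * 1.0)

# Divide-by-4 rule: the slope of the logistic curve is at most |beta|/4,
# so 100 m of extra distance lowers Pr(switching) by at most about 15.5 percentage points
max_drop = abs(b_dist100) / 4
print(round(p_0m, 3), round(p_100m, 3), round(max_drop, 3))
```

The actual drop between 0 and 100 m comes out just under the bound, because the fitted probabilities here are close to 0.5, where the curve is steepest.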
fig, ax = plt.subplots()
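rng = np.random.default_rng(123)
# Jitter the 0/1 outcomes toward the middle (1s down, 0s up) so overlapping points stay visible
y_jit = y + (1 - 2*y) * rng.uniform(0, 0.05, size=len(y))
ax.scatter(wells["dist"], y_jit, s=4, color="black", alpha=0.3)
# The model uses dist100, so convert the distance grid (meters) to 100 m units
xs = np.linspace(0, wells["dist"].max(), 300)
ax.plot(xs, expit(fit_2.params["Intercept"] + fit_2.params["dist100"] * xs/100), color="black")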
ax.set_xlabel("Distance to nearest safe well (meters)")
ax.set_ylabel("Pr(switching)")

Two predictors: distance + arsenic

fit_3 = smf.logit("switch ~ dist100 + arsenic", data=wells).fit()
fit_3.params
Optimization terminated successfully.
         Current function value: 0.650773
         Iterations 5
Intercept    0.002749
dist100     -0.896644
arsenic      0.460775
dtype: float64
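
The "Current function value" that statsmodels prints is the average negative log-likelihood per observation. Multiplying it by −n converts it to an in-sample log score on the same scale as the null-model scores above. This sketch takes n = 3020, the value implied by −2093.3 = n log 0.5, and hand-copies the printed values.

```python
import numpy as np

n = 3020  # number of wells: -2093.3 / log(0.5) is about 3020

# "Current function value" lines from the statsmodels output above (mean negative log-likelihood)
mean_nll = {"fit_2 (dist100)": 0.674874, "fit_3 (dist100 + arsenic)": 0.650773}

log_scores = {"null (p = 0.575)": -2059.0}
for name, f in mean_nll.items():
    log_scores[name] = -n * f  # equals fit.llf on the fitted results object

for name, score in log_scores.items():
    print(f"{name:28s} {score:8.1f}")
```

In a live session, fit_3.llf returns the same number without the copy and paste. These are in-sample scores, so they are somewhat optimistic about each added predictor. Cross-validation (for example LOO, as mentioned under BlackJAX below) gives a fairer comparison.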

Compare predicted curves at fixed arsenic levels:

fig, ax = plt.subplots()
ax.scatter(wells["dist"], y_jit, s=4, color="black", alpha=0.25)
for a, color in [(0.5, "gray"), (1.0, "black")]:
    p = expit(fit_3.params["Intercept"] + fit_3.params["dist100"]*xs/100 + fit_3.params["arsenic"]*a)
    ax.plot(xs, p, color=color, label=f"arsenic={a}")
ax.legend()
ax.set_xlabel("Distance to nearest safe well (meters)")
ax.set_ylabel("Pr(switching)")
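
Without an interaction, the two curves differ by a constant on the logit scale. Moving from arsenic 0.5 to 1.0 adds 0.5 × 0.460775 ≈ 0.230 to the linear predictor at every distance, while the gap in probability shrinks where the curves flatten out. Here is a quick numerical check with the fit_3 estimates copied from above. The 0 to 300 m grid is an illustrative choice.

```python
import numpy as np

def inv_logit(x):
    return 1 / (1 + np.exp(-x))

def logit(p):
    return np.log(p / (1 - p))

# fit_3 point estimates copied from the output above
b0, b_dist, b_ars = 0.002749, -0.896644, 0.460775

dist_m = np.linspace(0, 300, 7)  # illustrative distance grid in meters
p_low = inv_logit(b0 + b_dist * dist_m / 100 + b_ars * 0.5)
p_high = inv_logit(b0 + b_dist * dist_m / 100 + b_ars * 1.0)

logit_gap = logit(p_high) - logit(p_low)  # constant: 0.5 * b_ars
prob_gap = p_high - p_low                 # changes with distance
print(np.round(logit_gap, 4))
print(np.round(prob_gap, 3))
```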

Interaction

fit_4 = smf.logit("switch ~ dist100 * arsenic", data=wells).fit()
fit_4.params
Optimization terminated successfully.
         Current function value: 0.650270
         Iterations 5
Intercept         -0.147868
dist100           -0.577218
arsenic            0.555977
dist100:arsenic   -0.178906
dtype: float64
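
With the interaction, each predictor's slope depends on the other. On the logit scale the dist100 slope is −0.577 − 0.179 × arsenic, so distance matters more for households with higher arsenic levels. Likewise, arsenic matters less for households farther from a safe well. This also means the main-effect coefficient −0.577 is the distance slope for a hypothetical well with zero arsenic; centering the predictors would make it easier to read. The sketch below uses the fit_4 estimates copied from above. The helper names are illustrative.

```python
import numpy as np

# fit_4 point estimates copied from the output above
b_dist, b_ars, b_int = -0.577218, 0.555977, -0.178906

def dist_slope(arsenic):
    """Change in logit Pr(switch) per 100 m of distance, at a given arsenic level."""
    return b_dist + b_int * arsenic

def arsenic_slope(dist100):
    """Change in logit Pr(switch) per unit of arsenic, at a given distance (in 100 m)."""
    return b_ars + b_int * dist100

for a in [0.5, 1.0, 2.0]:
    print(f"arsenic={a}: dist100 slope = {dist_slope(a):.3f}")
for d in [0.0, 0.5, 1.0]:
    print(f"dist100={d}: arsenic slope = {arsenic_slope(d):.3f}")
```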

CmdStanPy equivalent

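For a Bayesian analog, fit a Bernoulli-logit regression in Stan. The normal(0, 2.5) prior below is a simple weakly informative choice applied to every coefficient, including the intercept. It is not identical to rstanarm's defaults, which adjust the coefficient prior scales to the predictors' scales, so estimates may differ slightly from stan_glm: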

data {
  int<lower=1> N;
  int<lower=1> K;
  matrix[N, K] X;  // predictors, including a column of ones for the intercept
  array[N] int<lower=0,upper=1> y;
}
parameters {
  vector[K] beta;
}
model {
  beta ~ normal(0, 2.5);
  y ~ bernoulli_logit(X * beta);
}
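
CmdStanPy passes data as a dict keyed by the names in the data block. Because the program has no separate intercept parameter, the intercept enters as a column of ones in X. The sketch below uses the five rows shown by wells.head() above as stand-ins for the full columns. The helper make_stan_data and the file name logit.stan are hypothetical, and the sampling call is left commented out because it needs a CmdStan installation.

```python
import numpy as np

def make_stan_data(y, *predictors):
    """Build the data dict for the Stan program above; X gets a leading column of ones."""
    y = np.asarray(y, dtype=int)
    X = np.column_stack([np.ones(len(y))] + [np.asarray(p, dtype=float) for p in predictors])
    return {"N": X.shape[0], "K": X.shape[1], "X": X, "y": y}

# The five rows printed by wells.head() above
dist100 = [0.16826, 0.47322, 0.20967, 0.21486, 0.40874]
arsenic = [2.36, 0.71, 2.07, 1.15, 1.10]
switch = [1, 1, 0, 1, 1]
stan_data = make_stan_data(switch, dist100, arsenic)
print(stan_data["N"], stan_data["K"])

# With CmdStan installed, sampling would look roughly like this (not run here):
# from cmdstanpy import CmdStanModel
# model = CmdStanModel(stan_file="logit.stan")  # hypothetical file holding the program above
# fit = model.sample(data=stan_data, chains=4, seed=123)
# print(fit.summary())
```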

BlackJAX relevance

This is a clean example for a hand-written logistic-regression log density:

log_lik = sum(y * log_sigmoid(X @ beta) + (1-y) * log_sigmoid(-(X @ beta)))
log_prior = sum(norm.logpdf(beta, 0, 2.5))

BlackJAX is useful when comparing NUTS behavior, studying priors under separation, or building custom PSIS/LOO workflows. For ordinary model-building exposition, CmdStanPy or PyMC is clearer.
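
A BlackJAX sampler such as NUTS needs a function from parameters to log density, and JAX supplies its gradient by autodiff. The sketch below writes the same log density in plain NumPy, using a numerically stable log-sigmoid and including the full normal log-prior constant. It adds the analytic gradient X^T(y − inverse-logit(Xβ)) − β/2.5² and checks that gradient by finite differences. Porting it to JAX would mostly mean swapping np for jax.numpy and using jax.nn.log_sigmoid. The helper names, the simulated data, and the "true" coefficients are assumptions for illustration.

```python
import numpy as np

def log_sigmoid(x):
    # log(inverse-logit(x)) = -log(1 + exp(-x)), computed without overflow
    return -np.logaddexp(0.0, -x)

def log_density(beta, X, y, prior_scale=2.5):
    """Bernoulli-logit log likelihood plus independent normal(0, prior_scale) log priors."""
    eta = X @ beta
    log_lik = np.sum(y * log_sigmoid(eta) + (1 - y) * log_sigmoid(-eta))
    log_prior = np.sum(-0.5 * np.log(2 * np.pi * prior_scale**2) - 0.5 * (beta / prior_scale) ** 2)
    return log_lik + log_prior

def grad_log_density(beta, X, y, prior_scale=2.5):
    """Analytic gradient: X^T (y - inverse-logit(X beta)) - beta / prior_scale^2."""
    p = np.exp(log_sigmoid(X @ beta))
    return X.T @ (y - p) - beta / prior_scale**2

# Simulated data (assumption): intercept, a distance-like and an arsenic-like predictor
rng = np.random.default_rng(1)
n = 500
X = np.column_stack([np.ones(n), rng.uniform(0, 3, n), rng.uniform(0.5, 5, n)])
beta_true = np.array([0.0, -0.9, 0.46])  # roughly the fit_3 estimates
y = rng.binomial(1, np.exp(log_sigmoid(X @ beta_true)))

# Central finite differences should match the analytic gradient
beta0 = np.array([0.1, -0.5, 0.3])
eps = 1e-6
fd_grad = np.array([
    (log_density(beta0 + eps * e, X, y) - log_density(beta0 - eps * e, X, y)) / (2 * eps)
    for e in np.eye(3)
])
print(np.max(np.abs(fd_grad - grad_log_density(beta0, X, y))))
```

In JAX, jax.grad would replace the hand-written gradient. A closure of log_density over X and y, taking beta alone, is the kind of function a BlackJAX sampler accepts as its log-density argument.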