Source: Arsenic/arsenic_logistic_building.Rmd
This ports the core model-building sequence for the Bangladesh wells example.
Load data
Code
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from scipy.special import expit
root = Path("../../ROS-Examples")
wells = pd.read_csv(root / "Arsenic/data/wells.csv")
wells["dist100"] = wells["dist"] / 100
wells.head()
   switch  arsenic       dist  dist100  assoc  educ  educ4
0       1     2.36  16.826000  0.16826      0     0    0.0
1       1     0.71  47.321999  0.47322      0     0    0.0
2       0     2.07  20.966999  0.20967      0    10    2.5
3       1     1.15  21.486000  0.21486      0    12    3.0
4       1     1.10  40.874001  0.40874      1    14    3.5
Null-model log scores
Code
y = wells["switch"].to_numpy()
for p in [0.5, y.mean()]:
    log_score = np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))
    print(round(p, 3), round(log_score, 1))
0.5 -2093.3
0.575 -2059.0
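The log score of a constant-probability model is the Bernoulli log-likelihood summed over observations. The following is a minimal sketch on a small made-up outcome vector (not the wells data; the name `null_log_score` is illustrative), showing that the sample mean gives a higher log score than a coin flip and is the best constant guess on a grid:

```python
import numpy as np

def null_log_score(y, p):
    """Sum of Bernoulli log-likelihoods when every observation gets probability p."""
    y = np.asarray(y, dtype=float)
    return float(np.sum(y * np.log(p) + (1 - y) * np.log(1 - p)))

# Made-up outcomes for illustration only (not the wells data).
y_demo = np.array([1, 1, 1, 0, 1, 0, 1, 0, 1, 0])

score_coin = null_log_score(y_demo, 0.5)            # pure guessing
score_mean = null_log_score(y_demo, y_demo.mean())  # constant equal to the sample mean

# Scan a grid of constant probabilities and keep the one with the highest score.
grid = np.linspace(0.01, 0.99, 99)
best_p = grid[np.argmax([null_log_score(y_demo, p) for p in grid])]
print(score_coin, score_mean, best_p)
```

Any fitted model should beat both of these baselines, which is why they are computed before adding predictors.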
Single predictor: distance
R original:
stan_glm(switch ~ dist100, family = binomial(link = "logit"), data = wells)
Python frequentist analog:
Code
fit_2 = smf.logit("switch ~ dist100", data=wells).fit()
fit_2.params
Optimization terminated successfully.
Current function value: 0.674874
Iterations 4
Intercept 0.605959
dist100 -0.621882
dtype: float64
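A common reading aid for logistic coefficients is the divide-by-4 rule: the coefficient divided by 4 bounds the change in probability per unit of the predictor. The sketch below applies it to the coefficients printed above, copied in by hand so that the block runs without the data (the helper `inv_logit` is illustrative and plays the role of `expit`):

```python
import numpy as np

def inv_logit(x):
    """Logistic function, mapping the linear predictor to a probability."""
    return 1.0 / (1.0 + np.exp(-x))

# Coefficients copied from the fit_2 output above.
intercept, b_dist100 = 0.605959, -0.621882

# Estimated probability of switching when the nearest safe well is 0 m away.
p_at_zero = inv_logit(intercept)

# Divide-by-4 rule: bound on the change in probability per unit of dist100
# (that is, per 100 m), reached where the curve is steepest (p = 0.5).
max_slope_per_100m = b_dist100 / 4

# Exact change in probability going from 0 m to 100 m.
diff_0_to_100 = inv_logit(intercept + b_dist100 * 1.0) - p_at_zero
print(p_at_zero, max_slope_per_100m, diff_0_to_100)
```

The exact difference is slightly smaller in magnitude than the divide-by-4 bound, as expected when the curve is evaluated away from its steepest point.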
Code
fig, ax = plt.subplots()
rng = np.random.default_rng(123)
y_jit = y + (1 - 2 * y) * rng.uniform(0, 0.05, size=len(y))
ax.scatter(wells["dist"], y_jit, s=4, color="black", alpha=0.3)
xs = np.linspace(0, wells["dist"].max(), 300)
ax.plot(xs, expit(fit_2.params["Intercept"] + fit_2.params["dist100"] * xs / 100), color="black")
ax.set_xlabel("Distance to nearest safe well (meters)")
ax.set_ylabel("Pr(switching)")
[Figure: jittered switching outcomes against distance to nearest safe well (meters), with the fitted curve for Pr(switching).]
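The jitter expression in the plotting code is compact, so a small sketch on made-up outcomes (not the wells data) may help show what it does. The factor `(1 - 2 * y)` flips the sign of the noise for `y = 1`, so every jittered point stays inside the unit interval:

```python
import numpy as np

rng = np.random.default_rng(123)
y_demo = np.array([0, 0, 1, 1, 0, 1])  # made-up outcomes for illustration

# (1 - 2*y) is +1 for y = 0 and -1 for y = 1, so the uniform noise pushes
# zeros upward and ones downward, keeping every point inside [0, 1].
noise = rng.uniform(0, 0.05, size=len(y_demo))
y_jit_demo = y_demo + (1 - 2 * y_demo) * noise
print(y_jit_demo)
```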
Two predictors: distance + arsenic
Code
fit_3 = smf.logit("switch ~ dist100 + arsenic", data=wells).fit()
fit_3.params
Optimization terminated successfully.
Current function value: 0.650773
Iterations 5
Intercept 0.002749
dist100 -0.896644
arsenic 0.460775
dtype: float64
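To compare these fits with the null-model log scores, the "Current function value" lines can be converted to total log scores. The sketch below assumes that statsmodels reports this value as the mean negative log-likelihood per observation, and infers the sample size from the p = 0.5 null score rather than reading it from the data:

```python
import math

# Assumption: N is inferred from the p = 0.5 null score above,
# since N * log(0.5) = -2093.3 implies N = 3020.
n_obs = round(-2093.3 / math.log(0.5))

# "Current function value" lines copied from the fit outputs above; this
# sketch assumes they are mean negative log-likelihoods per observation.
mean_neg_loglik = {
    "switch ~ dist100": 0.674874,
    "switch ~ dist100 + arsenic": 0.650773,
}

log_scores = {formula: -value * n_obs for formula, value in mean_neg_loglik.items()}
for formula, score in log_scores.items():
    print(f"{formula}: {score:.1f}")
```

With the fitted results in hand, the `llf` attribute (for example `fit_3.llf`) gives the log-likelihood directly and avoids the rounding in the printed values.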
Compare predicted curves at fixed arsenic levels:
Code
fig, ax = plt.subplots()
ax.scatter(wells["dist"], y_jit, s=4, color="black", alpha=0.25)
for a, color in [(0.5, "gray"), (1.0, "black")]:
    p = expit(fit_3.params["Intercept"] + fit_3.params["dist100"] * xs / 100 + fit_3.params["arsenic"] * a)
    ax.plot(xs, p, color=color, label=f"arsenic={a}")
ax.legend()
ax.set_xlabel("Distance to nearest safe well (meters)")
ax.set_ylabel("Pr(switching)")
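[Figure: jittered switching outcomes against distance to nearest safe well (meters), with fitted Pr(switching) curves for arsenic=0.5 and arsenic=1.0.]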
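The same curves can be read off numerically at a few distances. This sketch copies the fit_3 coefficients by hand so it runs without the data; the chosen distances and the `curves` dictionary are illustrative:

```python
import numpy as np

def inv_logit(x):
    """Logistic function, mapping the linear predictor to a probability."""
    return 1.0 / (1.0 + np.exp(-x))

# Coefficients copied from the fit_3 output above.
b0, b_dist100, b_arsenic = 0.002749, -0.896644, 0.460775

dist_m = np.array([0.0, 100.0, 200.0])  # distances in meters (illustrative)
curves = {}
for arsenic in (0.5, 1.0):
    # Same linear predictor as the plotting code, evaluated at three points.
    curves[arsenic] = inv_logit(b0 + b_dist100 * dist_m / 100 + b_arsenic * arsenic)
    print(arsenic, np.round(curves[arsenic], 3))
```

At every distance the higher arsenic level gives the higher predicted probability, and both curves fall as distance grows, which matches the signs of the two coefficients.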
Interaction
Code
fit_4 = smf.logit("switch ~ dist100 * arsenic", data=wells).fit()
fit_4.params
Optimization terminated successfully.
Current function value: 0.650270
Iterations 5
Intercept -0.147868
dist100 -0.577218
arsenic 0.555977
dist100:arsenic -0.178906
dtype: float64
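With an interaction, neither main effect has a single slope: the logit-scale slope of one predictor depends on the value of the other. A minimal sketch using the coefficients printed above (copied by hand; the helper names are illustrative):

```python
# Coefficients copied from the fit_4 output above.
coef = {
    "Intercept": -0.147868,
    "dist100": -0.577218,
    "arsenic": 0.555977,
    "dist100:arsenic": -0.178906,
}

def dist100_slope(arsenic):
    """Logit-scale slope of dist100 when arsenic is held at the given level."""
    return coef["dist100"] + coef["dist100:arsenic"] * arsenic

def arsenic_slope(dist100):
    """Logit-scale slope of arsenic when dist100 is held at the given value."""
    return coef["arsenic"] + coef["dist100:arsenic"] * dist100

for a in (0.5, 1.0, 2.0):
    print(f"arsenic={a}: dist100 slope = {dist100_slope(a):.3f}")
```

Because the interaction coefficient is negative, distance is estimated to matter more at higher arsenic levels, and arsenic to matter less at greater distances.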
CmdStanPy equivalent
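For a Bayesian analog, fit a Bernoulli-logit regression in Stan. Here X is the N-by-K design matrix, which must include a column of ones if an intercept is wanted, because the program has no separate intercept parameter: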
data {
  int<lower=1> N;
  int<lower=1> K;
  matrix[N, K] X;
  array[N] int<lower=0, upper=1> y;
}
parameters {
  vector[K] beta;
}
model {
  beta ~ normal(0, 2.5);
  y ~ bernoulli_logit(X * beta);
}
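One possible way to drive this program from Python is sketched below. The helper names `make_stan_data` and `fit_bernoulli_logit` are hypothetical, and the sampling step assumes that cmdstanpy and a working CmdStan installation are available, so it is wrapped in a function and not executed here:

```python
import numpy as np

def make_stan_data(predictors, y):
    """Build the data dict for the Stan program above.

    `predictors` maps column names to 1-D arrays; a leading column of ones
    is added so that the first element of beta acts as the intercept.
    """
    cols = [np.asarray(v, dtype=float) for v in predictors.values()]
    X = np.column_stack([np.ones(len(y))] + cols)
    return {"N": X.shape[0], "K": X.shape[1], "X": X, "y": np.asarray(y, dtype=int)}

def fit_bernoulli_logit(stan_file, data, seed=123):
    """Compile and sample; requires cmdstanpy and a working CmdStan install."""
    from cmdstanpy import CmdStanModel  # lazy import so make_stan_data runs without it
    model = CmdStanModel(stan_file=stan_file)
    return model.sample(data=data, chains=4, seed=seed)

# Tiny made-up inputs for illustration only (not the wells data).
demo = make_stan_data({"dist100": [0.2, 0.5, 1.1], "arsenic": [2.4, 0.7, 1.2]}, [1, 0, 1])
print(demo["N"], demo["K"])
```

With the wells data loaded, the call would look like `fit_bernoulli_logit("bernoulli_logit.stan", make_stan_data({"dist100": wells["dist100"], "arsenic": wells["arsenic"]}, wells["switch"]))`, where the file name is a placeholder for wherever the Stan program is saved.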
BlackJAX relevance
This is a clean example for a hand-written logistic-regression log density:
log_lik = sum(y * log_sigmoid(X @ beta) + (1 - y) * log_sigmoid(-(X @ beta)))
log_prior = sum(norm.logpdf(beta, 0, 2.5))
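The two lines above can be turned into a runnable reference function. The sketch below uses plain NumPy instead of JAX (a deliberate substitution so it runs without JAX installed); the same expressions should carry over to `jax.numpy`, and a NumPy version is handy for checking a JAX implementation. All function names and the demo arrays are illustrative:

```python
import numpy as np

def log_sigmoid(x):
    """Numerically stable log(1 / (1 + exp(-x)))."""
    return -np.logaddexp(0.0, -x)

def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale), written out to avoid a SciPy dependency."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def log_density(beta, X, y, prior_scale=2.5):
    """Unnormalized log posterior: Bernoulli-logit likelihood plus normal prior."""
    eta = X @ beta
    log_lik = np.sum(y * log_sigmoid(eta) + (1 - y) * log_sigmoid(-eta))
    log_prior = np.sum(normal_logpdf(beta, 0.0, prior_scale))
    return log_lik + log_prior

# Made-up design matrix (intercept plus one predictor) and outcomes.
X_demo = np.array([[1.0, 0.2], [1.0, 0.5], [1.0, 1.1], [1.0, 0.3]])
y_demo = np.array([1.0, 0.0, 0.0, 1.0])
print(log_density(np.zeros(2), X_demo, y_demo))
```

Writing `log_sigmoid` through `logaddexp` keeps the density finite even for very large linear predictors, which matters when a sampler explores extreme coefficient values.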
BlackJAX is useful when comparing NUTS behavior, studying priors under separation, or building custom PSIS/LOO workflows. For ordinary model-building exposition, CmdStanPy or PyMC is clearer.