Core Contribution
Marginal structural models (MSMs) target causal effects of treatment regimes when confounders both affect later treatment and are affected by earlier treatment. For two treatment times \(t=0,1\), a simple model is
\[
E[Y^{\bar a}] = \beta_0+\beta_1 a_0+\beta_2 a_1.
\]
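To make the parameters concrete, the sketch below evaluates the model for all four regimes \((a_0,a_1)\). The coefficient values are hypothetical and are not estimates from this section. The model has no \(a_0 a_1\) interaction, so \(\beta_1\) is the effect of setting \(a_0=1\) whatever value \(a_1\) is held at. "Always treat" versus "never treat" is therefore \(\beta_1+\beta_2\).

```python
import itertools

# Hypothetical coefficients, chosen only for illustration.
beta0, beta1, beta2 = 0.5, 2.0, 1.0

def msm_mean(a0: int, a1: int) -> float:
    """Counterfactual mean E[Y^{(a0, a1)}] implied by the linear MSM."""
    return beta0 + beta1 * a0 + beta2 * a1

# One counterfactual mean per treatment regime (a0, a1).
regimes = {a: msm_mean(*a) for a in itertools.product([0, 1], repeat=2)}
for a, m in regimes.items():
    print(a, m)

# Effect of a0 with a1 held at 0, and with a1 held at 1 (equal: no interaction term).
effect_a0_at_a1_0 = regimes[(1, 0)] - regimes[(0, 0)]
effect_a0_at_a1_1 = regimes[(1, 1)] - regimes[(0, 1)]

# "Always treat" versus "never treat".
always_vs_never = regimes[(1, 1)] - regimes[(0, 0)]
```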
Inverse probability weights create a pseudo-population in which treatment history is independent of measured confounder history:
\[
W_i=\prod_t \frac{1}{\Pr(A_{it}=a_{it}\mid \bar A_{i,t-1},\bar L_{it})}.
\]
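Each factor in \(W_i\) uses the probability of the treatment the subject *actually received*. A treated subject contributes \(1/p_t\) and an untreated subject contributes \(1/(1-p_t)\). A minimal sketch for one subject follows. The helper name `ipw_weight` and the probabilities are hypothetical.

```python
def ipw_weight(treated, prob_treated):
    """Unstabilized IPW weight: product over t of 1 / Pr(A_t = observed a_t | past).

    treated      -- observed treatments a_0, a_1, ... (0 or 1)
    prob_treated -- fitted Pr(A_t = 1 | past) for the same subject, same order
    """
    w = 1.0
    for a_t, p_t in zip(treated, prob_treated):
        w *= 1.0 / (p_t if a_t == 1 else 1.0 - p_t)
    return w

# Hypothetical subject: treated at t = 0 (Pr = 0.7), untreated at t = 1 (Pr(A_1 = 1) = 0.4).
w = ipw_weight([1, 0], [0.7, 0.4])  # 1 / (0.7 * 0.6), about 2.38
print(w)
```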
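A regression of \(Y\) on treatment history alone, with subject \(i\) weighted by \(W_i\), then estimates the parameters of the marginal structural model. No confounders enter this outcome regression. The estimates are valid provided there are no unmeasured confounders, the treatment models are correctly specified, and every treatment history has positive probability at each level of the measured confounders.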
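Weighted least squares can be run as ordinary least squares after multiplying each row of the design matrix and the outcome by \(\sqrt{W_i}\). The sketch below checks that this agrees with solving the weighted normal equations \((X^\top W X)\beta = X^\top W y\). It also checks that dividing the weights by their mean, as the Minimal Implementation does, leaves the estimates unchanged. The data and weights here are arbitrary illustrations, not a causal analysis.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed for this illustration
n = 500
# Design: intercept, a0, a1 (random binary columns)
X = np.c_[np.ones(n), rng.integers(0, 2, n), rng.integers(0, 2, n)]
y = X @ np.array([0.5, 2.0, 1.0]) + rng.normal(size=n)
w = rng.uniform(0.5, 5.0, size=n)  # arbitrary positive weights

# Route 1: scale rows by sqrt(w_i), then ordinary least squares.
sw = np.sqrt(w)
beta_scaled = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)[0]

# Route 2: solve the weighted normal equations directly.
XtW = X.T * w
beta_normal = np.linalg.solve(XtW @ X, XtW @ y)

# Rescaling the weights by a constant (here their mean) does not change the fit.
sw_rescaled = np.sqrt(w / w.mean())
beta_rescaled = np.linalg.lstsq(X * sw_rescaled[:, None], y * sw_rescaled, rcond=None)[0]
print(beta_scaled, beta_normal, beta_rescaled)
```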
Minimal Implementation
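Simulate time-varying confounding: \(L_1\) is affected by \(A_0\) and also predicts \(A_1\) and \(Y\).
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy import linalg, optimize, special

rng = np.random.default_rng()  # seed not stated in the original, so printed values will differ

n = 1200
L0 = rng.normal(size=n)                         # baseline confounder
p0 = special.expit(-0.1 + 0.9 * L0)
A0 = rng.binomial(1, p0)                        # first treatment depends on L0
L1 = 0.8 * L0 + 0.9 * A0 + rng.normal(size=n)  # time-varying confounder, affected by A0
p1 = special.expit(-0.2 + 0.8 * L1 + 0.5 * A0)
A1 = rng.binomial(1, p1)                        # second treatment depends on L1 and A0
Y = 1.0 * A0 + 1.3 * A1 + 0.9 * L0 + 1.0 * L1 + rng.normal(size=n)
A0.mean(), A1.mean()
```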
(np.float64(0.4825), np.float64(0.5941666666666666))
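To have a target for the estimates, the true MSM coefficients of this simulation can be derived by hand. Under the intervention that sets \((A_0,A_1)=(a_0,a_1)\), \(L_0\sim N(0,1)\) is unchanged, while \(L_1\) still responds to \(a_0\):
\[
E[Y^{a_0,a_1}] = 1.0\,a_0 + 1.3\,a_1 + 0.9\,E[L_0] + 1.0\,E[L_1^{a_0}],
\qquad
E[L_1^{a_0}] = 0.8\,E[L_0] + 0.9\,a_0 = 0.9\,a_0,
\]
\[
E[Y^{a_0,a_1}] = 1.9\,a_0 + 1.3\,a_1 .
\]
So \(\beta_0=0\), \(\beta_1=1.9\), \(\beta_2=1.3\). The coefficient 1.0 on \(A_0\) in the outcome equation is only the direct effect; the remaining \(0.9\times 1.0\) flows through \(L_1\).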
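Estimate the treatment probabilities in the denominator of \(W_i\) with logistic regressions. The model for \(A_1\) conditions on \(A_0\) and \(L_1\) only. That matches the simulation, where \(L_0\) does not enter \(p_1\), but in general it should condition on the full history \((\bar A_0, \bar L_1)\).
```python
def fit_logit(X, d):
    """Maximum-likelihood logistic regression via scipy.optimize.minimize."""
    def obj(b):
        xb = X @ b
        # Negative Bernoulli log-likelihood: -sum(d * xb - log(1 + exp(xb)))
        return -np.sum(d * xb - np.logaddexp(0, xb))
    return optimize.minimize(obj, np.zeros(X.shape[1])).x

# Denominator models: Pr(A0 | L0) and Pr(A1 | A0, L1)
X0 = np.c_[np.ones(n), L0]
X1 = np.c_[np.ones(n), L1, A0]
ph0 = special.expit(X0 @ fit_logit(X0, A0))
ph1 = special.expit(X1 @ fit_logit(X1, A1))

# Unstabilized weight: product over t of 1 / Pr(treatment actually received)
w = (A0 / ph0 + (1 - A0) / (1 - ph0)) * (A1 / ph1 + (1 - A1) / (1 - ph1))
w.mean(), np.quantile(w, [0.05, 0.95])
```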
(np.float64(4.076379119622979), array([ 1.40580129, 10.03347031]))
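A quick diagnostic is to check that the weights balance the confounders across treatment, which is what the pseudo-population is supposed to achieve. The sketch below re-simulates the same data-generating process with a larger sample and an arbitrary seed. It compares raw and weighted mean differences of \(L_0\) across \(A_0\) and of \(L_1\) across \(A_1\). One swap is made so the block needs only numpy: a few Newton-Raphson steps replace the `scipy.optimize` fit (the helper names `fit_logit_newton` and `weighted_diff` are hypothetical). \(L_1\) is *not* expected to balance across \(A_0\), because \(A_0\) causes \(L_1\).

```python
import numpy as np

def expit(x):
    return 1.0 / (1.0 + np.exp(-x))

def fit_logit_newton(X, d, n_iter=25):
    """Logistic regression by Newton-Raphson; a numpy-only stand-in for fit_logit."""
    b = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = expit(X @ b)
        hess = (X.T * (p * (1 - p))) @ X  # X' diag(p(1-p)) X
        b = b + np.linalg.solve(hess, X.T @ (d - p))
    return b

def weighted_diff(x, a, w):
    """Weighted mean of x among a == 1 minus weighted mean among a == 0."""
    t, c = a == 1, a == 0
    return np.average(x[t], weights=w[t]) - np.average(x[c], weights=w[c])

# Same data-generating process as the Minimal Implementation; larger n, arbitrary seed.
rng = np.random.default_rng(1)
n = 50_000
L0 = rng.normal(size=n)
A0 = rng.binomial(1, expit(-0.1 + 0.9 * L0))
L1 = 0.8 * L0 + 0.9 * A0 + rng.normal(size=n)
A1 = rng.binomial(1, expit(-0.2 + 0.8 * L1 + 0.5 * A0))

X0 = np.c_[np.ones(n), L0]
X1 = np.c_[np.ones(n), L1, A0]
ph0 = expit(X0 @ fit_logit_newton(X0, A0))
ph1 = expit(X1 @ fit_logit_newton(X1, A1))
w = (A0 / ph0 + (1 - A0) / (1 - ph0)) * (A1 / ph1 + (1 - A1) / (1 - ph1))

ones = np.ones(n)
raw_L0, wtd_L0 = weighted_diff(L0, A0, ones), weighted_diff(L0, A0, w)  # L0 across A0
raw_L1, wtd_L1 = weighted_diff(L1, A1, ones), weighted_diff(L1, A1, w)  # L1 across A1
print(f"L0 by A0: raw {raw_L0:.3f}, weighted {wtd_L0:.3f}")
print(f"L1 by A1: raw {raw_L1:.3f}, weighted {wtd_L1:.3f}")
print(f"mean weight {w.mean():.3f}")
```

With two binary treatments and correctly specified models, the unstabilized weights should average about 4, the number of treatment histories. That is consistent with the mean of roughly 4.08 reported above.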
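Fit the marginal structural model \(E[Y^{\bar a}]=\beta_0+\beta_1 a_0+\beta_2 a_1\) by weighted least squares. For comparison, also fit an ordinary regression of \(Y\) on \(A_0, A_1, L_0, L_1\).
```python
# Ordinary regression adjusting for L0 and L1; keep the A0 and A1 coefficients
ordinary = linalg.lstsq(np.c_[np.ones(n), A0, A1, L0, L1], Y)[0][1:3]

# MSM: regress Y on treatment history only, scaling each row by sqrt(weight)
X_msm = np.c_[np.ones(n), A0, A1]
sw = np.sqrt(w / w.mean())  # rescaling by the mean leaves the WLS estimates unchanged
msm = linalg.lstsq(X_msm * sw[:, None], Y * sw)[0][1:3]
ordinary, msm
```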
(array([0.95949876, 1.21576968]), array([1.37548673, 1.03425169]))
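A single simulated sample is noisy, especially with unstabilized weights. Averaging over many replicates shows what each estimator targets. In this design the total effect of \(A_0\) is \(1.0 + 0.9\times 1.0 = 1.9\), because \(A_0\) also acts through \(L_1\); the effect of \(A_1\) is 1.3. The sketch below repeats the section's simulation with arbitrary seeds and averages the two sets of coefficients. Newton-Raphson replaces the `scipy.optimize` logistic fit so the block needs only numpy. The helper names are hypothetical.

```python
import numpy as np

def expit(x):
    return 1.0 / (1.0 + np.exp(-x))

def fit_logit_newton(X, d, n_iter=25):
    """Logistic regression by Newton-Raphson (numpy-only stand-in for fit_logit)."""
    b = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = expit(X @ b)
        b = b + np.linalg.solve((X.T * (p * (1 - p))) @ X, X.T @ (d - p))
    return b

def one_replicate(rng, n=1200):
    """Simulate the section's data once; return (ordinary, msm) coefficients on A0, A1."""
    L0 = rng.normal(size=n)
    A0 = rng.binomial(1, expit(-0.1 + 0.9 * L0))
    L1 = 0.8 * L0 + 0.9 * A0 + rng.normal(size=n)
    A1 = rng.binomial(1, expit(-0.2 + 0.8 * L1 + 0.5 * A0))
    Y = 1.0 * A0 + 1.3 * A1 + 0.9 * L0 + 1.0 * L1 + rng.normal(size=n)

    one = np.ones(n)
    X0, X1 = np.c_[one, L0], np.c_[one, L1, A0]
    ph0 = expit(X0 @ fit_logit_newton(X0, A0))
    ph1 = expit(X1 @ fit_logit_newton(X1, A1))
    w = (A0 / ph0 + (1 - A0) / (1 - ph0)) * (A1 / ph1 + (1 - A1) / (1 - ph1))

    ordinary = np.linalg.lstsq(np.c_[one, A0, A1, L0, L1], Y, rcond=None)[0][1:3]
    sw = np.sqrt(w)
    msm = np.linalg.lstsq(np.c_[one, A0, A1] * sw[:, None], Y * sw, rcond=None)[0][1:3]
    return ordinary, msm

rng = np.random.default_rng(2024)  # arbitrary seed
results = [one_replicate(rng) for _ in range(300)]
ordinary_mean = np.mean([r[0] for r in results], axis=0)
msm_mean = np.mean([r[1] for r in results], axis=0)
print("ordinary adjusted:", ordinary_mean)  # centers on the direct effects (1.0, 1.3)
print("IPW MSM:          ", msm_mean)       # centers on the total effects (1.9, 1.3)
```

The ordinary regression conditions on \(L_1\), which blocks the \(A_0 \to L_1 \to Y\) path. Its \(A_0\) coefficient therefore settles near the direct effect rather than the total effect the MSM targets.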
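Compare the ordinary adjusted coefficients with the IPW marginal structural estimates. The ordinary regression conditions on \(L_1\), which lies on the path \(A_0 \to L_1 \to Y\). Its \(A_0\) coefficient therefore estimates only the direct effect (1.0 in the simulation), not the total effect the MSM targets (\(1.0 + 0.9\times 1.0 = 1.9\)). For \(A_1\), both target 1.3.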
```python
fig, ax = plt.subplots(figsize=(6, 3.4))
loc = np.arange(2)
ax.bar(loc - 0.18, ordinary, width=0.35, label="ordinary adjusted")
ax.bar(loc + 0.18, msm, width=0.35, label="IPW MSM")
# True MSM coefficients: A0 acts directly (1.0) and through L1 (0.9 * 1.0), giving 1.9; A1 gives 1.3
ax.scatter(loc, [1.9, 1.3], color="#66ff99", marker="x", s=70, label="truth (MSM)")
ax.set(xticks=loc, xticklabels=["A0", "A1"], ylabel="coefficient")
ax.legend()
plt.show()
```