Statistical Ideas in Code

Maximum Likelihood from Incomplete Data via the EM Algorithm

Dempster, Laird, and Rubin (1977)

Core Contribution

The EM algorithm turns a difficult incomplete-data likelihood into a sequence of easier complete-data problems. Let \(Y\) be the observed data, \(Z\) the missing or latent data, and \(\theta^{(t)}\) the current parameter value. The E-step forms

\[ Q(\theta\mid\theta^{(t)}) = E_{\theta^{(t)}}[\log p(Y,Z\mid\theta)\mid Y], \]

and the M-step updates

\[ \theta^{(t+1)}=\arg\max_\theta Q(\theta\mid\theta^{(t)}). \]

No iteration can decrease the observed-data likelihood \(p(Y\mid\theta)\). In a mixture model, the missing data are the component labels.
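The monotonicity claim follows from a short standard argument. Define \(H(\theta\mid\theta^{(t)}) = E_{\theta^{(t)}}[\log p(Z\mid Y,\theta)\mid Y]\). Because \(p(Y,Z\mid\theta)=p(Y\mid\theta)\,p(Z\mid Y,\theta)\), taking the conditional expectation of the logarithm gives

\[ \begin{aligned} \log p(Y\mid\theta) &= Q(\theta\mid\theta^{(t)}) - H(\theta\mid\theta^{(t)}), \\ \log p(Y\mid\theta^{(t+1)}) - \log p(Y\mid\theta^{(t)}) &= \underbrace{\big[Q(\theta^{(t+1)}\mid\theta^{(t)}) - Q(\theta^{(t)}\mid\theta^{(t)})\big]}_{\ge 0 \text{ (M-step)}} - \underbrace{\big[H(\theta^{(t+1)}\mid\theta^{(t)}) - H(\theta^{(t)}\mid\theta^{(t)})\big]}_{\le 0 \text{ (Gibbs' inequality)}} \ge 0. \end{aligned} \]

The second bracket is non-positive because \(E_q[\log p]\le E_q[\log q]\) for any densities \(p,q\), here with \(q=p(Z\mid Y,\theta^{(t)})\). The argument only needs the M-step to increase \(Q\), not to maximize it exactly.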

Minimal Implementation

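Simulate \(n=240\) observations from a two-component Gaussian mixture in which the missing variable is the latent component label \(Z_i\): \(Z_i=1\) with probability 0.38, and \(Y_i\) is normal with mean 2.0 and standard deviation 0.55 when \(Z_i=1\), or mean \(-1.2\) and standard deviation 0.75 otherwise.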

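import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
rng = np.random.default_rng()  # the seed behind the output below is not given in the source
n = 240
z = rng.binomial(1, 0.38, n)  # latent labels Z_i ~ Bernoulli(0.38)
y = rng.normal(np.where(z == 1, 2.0, -1.2), np.where(z == 1, 0.55, 0.75))  # observed data
y[:8]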
array([-1.15671141,  1.87257003,  0.0641122 , -0.65140745, -1.45503641,
       -2.45273291,  2.620375  ,  1.60487901])

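Initialize \(\theta=(\pi,\mu,\sigma)\), where \(\pi=\Pr(Z_i=1)\) and \(\mu,\sigma\) hold the two components' means and standard deviations. Then alternate the E-step, which computes the responsibilities \(w_i=\Pr(Z_i=1\mid y_i,\theta)\), with the M-step, which updates \(\pi\), \(\mu\), and \(\sigma\) as responsibility-weighted proportions, means, and standard deviations.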
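Written out for this model, with \(\phi(y;\mu,\sigma)\) the normal density and components indexed 0 and 1 as in the code, the E-step and the complete-data objective are

\[ w_i = \frac{\pi\,\phi(y_i;\mu_1,\sigma_1)}{(1-\pi)\,\phi(y_i;\mu_0,\sigma_0)+\pi\,\phi(y_i;\mu_1,\sigma_1)}, \qquad Q(\theta\mid\theta^{(t)}) = \sum_{i=1}^n \Big[(1-w_i)\big\{\log(1-\pi)+\log\phi(y_i;\mu_0,\sigma_0)\big\} + w_i\big\{\log\pi+\log\phi(y_i;\mu_1,\sigma_1)\big\}\Big], \]

where \(w_i\) is evaluated at \(\theta^{(t)}\). Maximizing \(Q\) term by term gives the weighted-moment M-step

\[ \pi = \frac{1}{n}\sum_{i=1}^n w_i, \qquad \mu_1 = \frac{\sum_i w_i y_i}{\sum_i w_i}, \qquad \sigma_1^2 = \frac{\sum_i w_i (y_i-\mu_1)^2}{\sum_i w_i}, \]

with \(\mu_0,\sigma_0\) obtained by replacing \(w_i\) with \(1-w_i\). Each variance uses the freshly updated mean, matching the order of the updates in the loop.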

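# theta = (pi, mu, sig): pi = Pr(Z = 1); mu and sig are ordered (component 0, component 1)
pi, mu, sig = 0.5, np.array([-0.2, 1.2]), np.array([1.0, 1.0])
ll = []  # observed-data log likelihood, one entry per iteration
for _ in range(50):
    # E-step: weighted component densities and responsibilities w_i = Pr(Z_i = 1 | y_i, theta)
    d0 = (1 - pi) * stats.norm.pdf(y, mu[0], sig[0])
    d1 = pi * stats.norm.pdf(y, mu[1], sig[1])
    w = d1 / (d0 + d1)
    # M-step: responsibility-weighted proportion, means, and standard deviations
    pi = w.mean()
    mu = np.array([(1 - w) @ y / (1 - w).sum(), w @ y / w.sum()])
    sig = np.sqrt([((1 - w) * (y - mu[0])**2).sum() / (1 - w).sum(),
                   (w * (y - mu[1])**2).sum() / w.sum()])
    # log p(y | theta) at the parameters used in this E-step, i.e. before the update above
    ll.append(np.log(d0 + d1).sum())
pi, mu, sig, ll[-1]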
(np.float64(0.3971956703149308),
 array([-1.24267976,  2.09595405]),
 array([0.76860609, 0.55888281]),
 np.float64(-404.5923374138181))
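To check the monotonicity claim numerically rather than by eye, here is a self-contained sketch that repeats the same E- and M-steps on freshly simulated data. The fixed seed and the helper names `normal_pdf` and `em_two_gaussians` are our own choices, not from the source; the normal density is written out so the sketch needs only NumPy.

```python
import numpy as np

def normal_pdf(x, mean, sd):
    """Gaussian density, written out so this sketch depends only on NumPy."""
    return np.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * np.sqrt(2 * np.pi))

def em_two_gaussians(y, pi=0.5, mu=(-0.2, 1.2), sig=(1.0, 1.0), n_iter=50):
    """Two-component Gaussian mixture EM with the same updates as the loop above.

    Returns the final (pi, mu, sig) and the log-likelihood trace, where entry t
    is log p(y | theta^(t)) evaluated before the t-th M-step.
    """
    mu, sig = np.asarray(mu, dtype=float), np.asarray(sig, dtype=float)
    ll = []
    for _ in range(n_iter):
        # E-step: weighted densities and responsibilities
        d0 = (1 - pi) * normal_pdf(y, mu[0], sig[0])
        d1 = pi * normal_pdf(y, mu[1], sig[1])
        w = d1 / (d0 + d1)
        ll.append(np.log(d0 + d1).sum())
        # M-step: weighted proportion, means, then standard deviations
        pi = w.mean()
        mu = np.array([(1 - w) @ y / (1 - w).sum(), w @ y / w.sum()])
        sig = np.sqrt([((1 - w) * (y - mu[0]) ** 2).sum() / (1 - w).sum(),
                       (w * (y - mu[1]) ** 2).sum() / w.sum()])
    return pi, mu, sig, np.array(ll)

rng = np.random.default_rng(0)  # arbitrary seed chosen for this sketch
z = rng.binomial(1, 0.38, 240)
y = rng.normal(np.where(z == 1, 2.0, -1.2), np.where(z == 1, 0.55, 0.75))
pi_hat, mu_hat, sig_hat, ll = em_two_gaussians(y)

# Every step should be non-negative up to floating-point rounding.
print(np.diff(ll).min() >= -1e-9)
```

The small tolerance allows for rounding once the iterates have converged and successive log likelihoods agree to many digits.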

Plot the fitted mixture and the monotone observed-data log likelihood.

grid = np.linspace(y.min() - 1, y.max() + 1, 300)
# fitted density (1 - pi) N(mu0, sig0) + pi N(mu1, sig1) evaluated on the grid
mix = (1 - pi) * stats.norm.pdf(grid, mu[0], sig[0]) + pi * stats.norm.pdf(grid, mu[1], sig[1])
fig, ax = plt.subplots(1, 2, figsize=(8, 3.4))
# left panel: histogram of the data with the fitted density overlaid
ax[0].hist(y, bins=30, density=True, alpha=0.55)
ax[0].plot(grid, mix, color="#ffcc66", lw=2.5)
ax[0].set(title="fitted mixture")
# right panel: observed-data log likelihood by iteration
ax[1].plot(ll, marker="o", ms=3)
ax[1].set(title="monotone likelihood", xlabel="iteration")
plt.show()

EM alternates between posterior responsibilities (E-step) and weighted Gaussian updates (M-step).
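In the loop above, `d0` and `d1` are raw densities. For a point many standard deviations from both component means, both can underflow to 0.0, and `w = d1 / (d0 + d1)` becomes `nan`. A common safeguard, sketched here as an alternative to what the source does, computes the E-step in log space with `np.logaddexp`. The helper names `log_normal_pdf` and `responsibilities_log` and the test points are hypothetical.

```python
import numpy as np

def log_normal_pdf(x, mean, sd):
    """Log of the Gaussian density."""
    return -0.5 * ((x - mean) / sd) ** 2 - np.log(sd) - 0.5 * np.log(2 * np.pi)

def responsibilities_log(y, pi, mu, sig):
    """Return w_i = Pr(Z_i = 1 | y_i, theta) and the observed-data log likelihood.

    Works with log densities throughout, so points far in both tails give
    finite responsibilities instead of 0/0.
    """
    a0 = np.log1p(-pi) + log_normal_pdf(y, mu[0], sig[0])  # log[(1 - pi) phi_0(y)]
    a1 = np.log(pi) + log_normal_pdf(y, mu[1], sig[1])     # log[pi phi_1(y)]
    log_mix = np.logaddexp(a0, a1)                          # log[(1 - pi) phi_0 + pi phi_1]
    return np.exp(a1 - log_mix), log_mix.sum()

# 60.0 lies about 80 and 100 standard deviations from the two means,
# where the raw densities underflow to zero.
y_far = np.array([-1.0, 2.0, 60.0])
w, loglik = responsibilities_log(y_far, 0.4, np.array([-1.2, 2.1]), np.array([0.77, 0.56]))
```

For ordinary points this returns the same responsibilities as the direct formula; it differs only where the direct formula would produce `nan`.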

Implementations

scikit-learn's GaussianMixture (Python); the R packages mclust and mixtools
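As a rough usage sketch of the first of these, assuming scikit-learn is installed, `GaussianMixture` fits the same two-component model by EM. It expects a 2-D array of shape `(n_samples, n_features)`, so a 1-D sample needs a reshape. The seed and the larger sample size are our own arbitrary choices.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(1)  # arbitrary seed for this sketch
z = rng.binomial(1, 0.38, 2000)
y = rng.normal(np.where(z == 1, 2.0, -1.2), np.where(z == 1, 0.55, 0.75))
X = y.reshape(-1, 1)  # one feature per row

gm = GaussianMixture(n_components=2, random_state=0).fit(X)

# Component order is arbitrary, so sort components by their means.
order = np.argsort(gm.means_.ravel())
weights = gm.weights_[order]                    # mixing proportions
means = gm.means_.ravel()[order]
sds = np.sqrt(gm.covariances_.ravel()[order])   # "full" covariances have shape (2, 1, 1)
resp = gm.predict_proba(X)[:, order]            # responsibilities, one column per component
```

Unlike the hand-written loop, scikit-learn initializes with k-means by default, stops on a convergence tolerance rather than a fixed iteration count, and adds a small `reg_covar` term to each variance, so its estimates can differ slightly from plain EM.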