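import numpy as np

# rng is assumed to be a NumPy random Generator created earlier; its seed is not shown here
n = 240
z = rng.binomial(1, 0.38, n)  # latent labels: Z_i = 1 with probability 0.38
y = rng.normal(np.where(z == 1, 2.0, -1.2), np.where(z == 1, 0.55, 0.75))  # component-specific mean and sd
y[:8]
array([-1.15671141,  1.87257003,  0.0641122 , -0.65140745, -1.45503641,
       -2.45273291,  2.620375  ,  1.60487901])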
Dempster, Laird, and Rubin (1977)
The EM algorithm converts difficult incomplete-data likelihoods into a sequence of easier complete-data problems. Let \(Z\) be missing or latent data and \(\theta^{(t)}\) the current parameter value. The E-step forms
\[ Q(\theta\mid\theta^{(t)}) = E_{\theta^{(t)}}[\log p(Y,Z\mid\theta)\mid Y], \]
and the M-step updates
\[ \theta^{(t+1)}=\arg\max_\theta Q(\theta\mid\theta^{(t)}). \]
Each iteration is guaranteed not to decrease the observed-data likelihood \(p(Y\mid\theta)\). In a mixture model, the missing data are the component labels.
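The monotonicity claim follows from a standard decomposition, sketched here. The function \(H\) is notation introduced for this sketch and does not appear elsewhere in the text. Since \(\log p(Y\mid\theta)=\log p(Y,Z\mid\theta)-\log p(Z\mid Y,\theta)\), taking the conditional expectation under \(\theta^{(t)}\) gives
\[ \log p(Y\mid\theta) = Q(\theta\mid\theta^{(t)}) + H(\theta\mid\theta^{(t)}), \qquad H(\theta\mid\theta^{(t)}) = -E_{\theta^{(t)}}[\log p(Z\mid Y,\theta)\mid Y]. \]
By Gibbs' inequality, \(H(\theta\mid\theta^{(t)})\ge H(\theta^{(t)}\mid\theta^{(t)})\) for every \(\theta\), so
\[ \log p(Y\mid\theta^{(t+1)})-\log p(Y\mid\theta^{(t)}) = \bigl[Q(\theta^{(t+1)}\mid\theta^{(t)})-Q(\theta^{(t)}\mid\theta^{(t)})\bigr] + \bigl[H(\theta^{(t+1)}\mid\theta^{(t)})-H(\theta^{(t)}\mid\theta^{(t)})\bigr] \ge 0. \]
The first bracket is nonnegative because the M-step maximizes \(Q\), and the second is nonnegative by the inequality above.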
The data are simulated from a two-component normal mixture in which the missing variable is the latent component label \(Z_i\).
Initialize \(\theta=(\pi,\mu,\sigma)\), then alternate the E-step responsibility \(w_i=\Pr(Z_i=1\mid y_i,\theta)\) and M-step weighted moments.
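For a two-component normal mixture these steps have closed forms, written out here as a sketch. The symbol \(\phi(y;\mu,\sigma)\) for the normal density with mean \(\mu\) and standard deviation \(\sigma\), and the component indices 0 and 1, are notation assumed for this sketch, with \(\pi=\Pr(Z_i=1)\). The E-step responsibility is
\[ w_i = \frac{\pi\,\phi(y_i;\mu_1,\sigma_1)}{(1-\pi)\,\phi(y_i;\mu_0,\sigma_0)+\pi\,\phi(y_i;\mu_1,\sigma_1)}, \]
and the M-step sets
\[ \pi = \frac{1}{n}\sum_{i=1}^n w_i, \qquad \mu_1 = \frac{\sum_i w_i y_i}{\sum_i w_i}, \qquad \sigma_1^2 = \frac{\sum_i w_i (y_i-\mu_1)^2}{\sum_i w_i}, \]
with \(\mu_0\) and \(\sigma_0^2\) given by the same expressions after replacing \(w_i\) by \(1-w_i\).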
from scipy import stats

# Starting values; pi is the weight of component 1
pi, mu, sig = 0.5, np.array([-0.2, 1.2]), np.array([1.0, 1.0])
ll = []
for _ in range(50):
    # E-step: weighted component densities and responsibilities
    d0 = (1 - pi) * stats.norm.pdf(y, mu[0], sig[0])
    d1 = pi * stats.norm.pdf(y, mu[1], sig[1])
    w = d1 / (d0 + d1)
    # M-step: weighted proportion, means, and standard deviations
    pi = w.mean()
    mu = np.array([(1 - w) @ y / (1 - w).sum(), w @ y / w.sum()])
    sig = np.sqrt([((1 - w) * (y - mu[0])**2).sum() / (1 - w).sum(),
                   (w * (y - mu[1])**2).sum() / w.sum()])
    # Observed-data log likelihood at the parameters used in this E-step
    ll.append(np.log(d0 + d1).sum())
pi, mu, sig, ll[-1]
(np.float64(0.3971956703149308),
 array([-1.24267976, 2.09595405]),
 array([0.76860609, 0.55888281]),
 np.float64(-404.5923374138181))
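The loop above can also be packaged as a reusable function. The following is a minimal sketch under stated assumptions, not the text's own implementation: the function name `em_two_normals`, the arguments `max_iter` and `tol`, the seed `0`, and the use of log-space densities with `scipy.special.logsumexp` are all choices made for this illustration. It simulates data with the same design as the example and checks numerically that the recorded log likelihood never decreases.

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp


def em_two_normals(y, pi, mu, sig, max_iter=200, tol=1e-8):
    """EM for a two-component normal mixture; pi is the weight of component 1.

    ll[k] is the observed-data log likelihood at the parameters entering
    iteration k (hypothetical helper, written for this sketch).
    """
    mu = np.asarray(mu, dtype=float).copy()
    sig = np.asarray(sig, dtype=float).copy()
    ll = []
    for _ in range(max_iter):
        # E-step in log space: log of each weighted component density
        log_d = np.stack([
            np.log1p(-pi) + stats.norm.logpdf(y, mu[0], sig[0]),
            np.log(pi) + stats.norm.logpdf(y, mu[1], sig[1]),
        ])
        log_mix = logsumexp(log_d, axis=0)  # log mixture density per point
        ll.append(log_mix.sum())
        # Stop once the log likelihood gain falls below the tolerance
        if len(ll) > 1 and ll[-1] - ll[-2] < tol:
            break
        w = np.exp(log_d[1] - log_mix)  # responsibilities Pr(Z_i = 1 | y_i)
        # M-step: weighted proportion, means, and standard deviations
        pi = w.mean()
        mu = np.array([(1 - w) @ y / (1 - w).sum(), w @ y / w.sum()])
        sig = np.sqrt([
            ((1 - w) * (y - mu[0]) ** 2).sum() / (1 - w).sum(),
            (w * (y - mu[1]) ** 2).sum() / w.sum(),
        ])
    return pi, mu, sig, np.array(ll)


# Simulated data with the same design as the example; the seed is an assumption
rng = np.random.default_rng(0)
n = 240
z = rng.binomial(1, 0.38, n)
y = rng.normal(np.where(z == 1, 2.0, -1.2), np.where(z == 1, 0.55, 0.75))

pi_hat, mu_hat, sig_hat, ll_path = em_two_normals(y, 0.5, [-0.2, 1.2], [1.0, 1.0])

# Smallest one-step change in the log likelihood; EM theory says it is >= 0
min_step = np.diff(ll_path).min()
print(pi_hat, mu_hat, sig_hat, ll_path[-1], min_step)
```

Working with log densities avoids dividing by a mixture density that has underflowed to zero for points far from both components, which the direct `d1 / (d0 + d1)` form does not guard against. The stopping rule on the log likelihood gain is one common choice; a fixed iteration count, as in the loop above, is equally valid for a demonstration.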
Plot the fitted mixture and the monotone observed-data log likelihood.
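import matplotlib.pyplot as plt

# Fitted mixture density on a grid slightly wider than the data range
grid = np.linspace(y.min() - 1, y.max() + 1, 300)
mix = (1 - pi) * stats.norm.pdf(grid, mu[0], sig[0]) + pi * stats.norm.pdf(grid, mu[1], sig[1])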
fig, ax = plt.subplots(1, 2, figsize=(8, 3.4))
ax[0].hist(y, bins=30, density=True, alpha=0.55)
ax[0].plot(grid, mix, color="#ffcc66", lw=2.5)
ax[0].set(title="fitted mixture")
ax[1].plot(ll, marker="o", ms=3)
ax[1].set(title="monotone likelihood", xlabel="iteration")
plt.show()