Statistical Ideas in Code
The Baum-Welch Algorithm

Leonard Baum and Lloyd Welch (1970)

Core Contribution

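Baum-Welch is the EM algorithm for hidden Markov models (HMMs). Hidden states \(S_t\) follow a Markov chain with initial distribution \(\pi\) and transition matrix \(A\), where \(A_{ij}=\Pr(S_{t+1}=j\mid S_t=i)\). Each observation \(Y_t\) is emitted from a state-specific distribution; for discrete symbols this is a matrix \(B\) with \(B_{ik}=\Pr(Y_t=k\mid S_t=i)\). The likelihood sums over all hidden paths: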

\[ p(y_{1:T})=\sum_{s_{1:T}} \pi_{s_1}B_{s_1y_1}\prod_{t=2}^T A_{s_{t-1}s_t}B_{s_ty_t}. \]
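A quick way to check the formula is to evaluate it literally on a short sequence. The sketch below is pure Python. The function names and the six-symbol sequence are illustrative. \(\pi\) is uniform, and \(A\) and \(B\) take the simulation values used later on this page. The code enumerates all \(2^6=64\) paths and compares the total with the forward recursion, which reaches the same number with \(O(TK^2)\) work for \(K\) states instead of \(K^T\) terms.

```python
import itertools


def brute_force_likelihood(pi, A, B, y):
    """Sum pi[s1] B[s1][y1] prod_t A[s_{t-1}][s_t] B[s_t][y_t] over every hidden path."""
    K = len(pi)
    total = 0.0
    for path in itertools.product(range(K), repeat=len(y)):
        p = pi[path[0]] * B[path[0]][y[0]]
        for t in range(1, len(y)):
            p *= A[path[t - 1]][path[t]] * B[path[t]][y[t]]
        total += p
    return total


def forward_likelihood(pi, A, B, y):
    """Same quantity via alpha_t(j) = (sum_i alpha_{t-1}(i) A[i][j]) * B[j][y_t]."""
    K = len(pi)
    alpha = [pi[i] * B[i][y[0]] for i in range(K)]
    for t in range(1, len(y)):
        alpha = [sum(alpha[i] * A[i][j] for i in range(K)) * B[j][y[t]] for j in range(K)]
    return sum(alpha)


# Illustrative parameters and a short observed sequence
pi_ex = [0.5, 0.5]
A_ex = [[0.88, 0.12], [0.18, 0.82]]
B_ex = [[0.80, 0.20], [0.25, 0.75]]
y_ex = [1, 1, 0, 0, 1, 0]

print(brute_force_likelihood(pi_ex, A_ex, B_ex, y_ex))
print(forward_likelihood(pi_ex, A_ex, B_ex, y_ex))  # equal up to floating-point rounding
```

Doubling the sequence length squares the number of paths the brute-force sum must visit, while the forward pass grows only linearly in \(T\). That gap is why Baum-Welch is built on forward-backward.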

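The forward-backward algorithm (the E-step) computes the smoothed state probabilities

\[ \gamma_t(i)=\Pr(S_t=i\mid y_{1:T}) \]

and the expected transitions \(\xi_t(i,j)=\Pr(S_t=i,S_{t+1}=j\mid y_{1:T})\). The M-step normalizes these expected counts to update \(A\) and \(B\); the full algorithm also sets \(\pi\) to \(\gamma_1\). No EM iteration can decrease the likelihood.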
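For reference, here are the standard recursions behind these quantities, in the notation above. The forward and backward variables \(\alpha_t(i)=p(y_{1:t},S_t=i)\) and \(\beta_t(i)=p(y_{t+1:T}\mid S_t=i)\) are not defined elsewhere in the text and are introduced here. The last row gives the M-step for discrete emissions, with \(p(y_{1:T})=\sum_i\alpha_T(i)\):

\[
\begin{aligned}
\alpha_1(i) &= \pi_i B_{iy_1}, &
\alpha_t(j) &= \Big(\sum_{i} \alpha_{t-1}(i)A_{ij}\Big) B_{jy_t},\\
\beta_T(i) &= 1, &
\beta_t(i) &= \sum_{j} A_{ij}B_{jy_{t+1}}\beta_{t+1}(j),\\
\gamma_t(i) &= \frac{\alpha_t(i)\beta_t(i)}{p(y_{1:T})}, &
\xi_t(i,j) &= \frac{\alpha_t(i)A_{ij}B_{jy_{t+1}}\beta_{t+1}(j)}{p(y_{1:T})},\\
\hat\pi_i &= \gamma_1(i), &
\hat A_{ij} &= \frac{\sum_{t=1}^{T-1}\xi_t(i,j)}{\sum_{t=1}^{T-1}\gamma_t(i)},\qquad
\hat B_{ik} = \frac{\sum_{t=1}^{T}\gamma_t(i)\,\mathbf{1}\{y_t=k\}}{\sum_{t=1}^{T}\gamma_t(i)}.
\end{aligned}
\]

Because \(\sum_j\xi_t(i,j)=\gamma_t(i)\), the denominator of \(\hat A_{ij}\) is the row sum of its numerator. The update therefore amounts to normalizing the expected transition counts row by row. For long sequences, \(\alpha_t\) and \(\beta_t\) are rescaled at each step to avoid underflow. This leaves \(\gamma_t\) and \(\xi_t\) unchanged, since both are normalized within each \(t\).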

Minimal Implementation

Simulate hidden states \(S_t\) and observed symbols \(Y_t\) from a two-state HMM.

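```python
import numpy as np

rng = np.random.default_rng(0)  # seed is an assumption: the original seed is not shown, so draws may differ from the output below

T = 80
A_true = np.array([[0.88, 0.12], [0.18, 0.82]])  # A_true[i, j] = Pr(S_{t+1}=j | S_t=i)
B_true = np.array([[0.80, 0.20], [0.25, 0.75]])  # B_true[i, k] = Pr(Y_t=k | S_t=i)
s = np.zeros(T, dtype=int)
y = np.zeros(T, dtype=int)
s[0] = rng.integers(2)  # uniform initial state
y[0] = rng.choice(2, p=B_true[s[0]])
for t in range(1, T):
    s[t] = rng.choice(2, p=A_true[s[t-1]])  # Markov transition from the previous state
    y[t] = rng.choice(2, p=B_true[s[t]])    # emission given the current state
y[:12]
```

```
array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0])
```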
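The simulation parameters imply what the sequence should look like on average. The sketch below is pure Python, and the helper `two_state_summary` is hypothetical, written only for this illustration. It computes three quantities for this transition and emission matrix:

- the stationary distribution of the chain;
- the expected run length in each state, since sojourn times are geometric;
- the long-run frequency of each symbol.

```python
def two_state_summary(A, B):
    """Long-run behaviour of a 2-state HMM with transition matrix A and emission matrix B."""
    a, b = A[0][1], A[1][0]                      # switching probabilities 0->1 and 1->0
    stationary = [b / (a + b), a / (a + b)]      # solves pi = pi A for a two-state chain
    mean_run = [1 / (1 - A[0][0]), 1 / (1 - A[1][1])]  # expected geometric sojourn times
    # marginal probability of each symbol once the chain is at stationarity
    symbol = [sum(stationary[i] * B[i][k] for i in range(2)) for k in range(len(B[0]))]
    return stationary, mean_run, symbol


A_demo = [[0.88, 0.12], [0.18, 0.82]]
B_demo = [[0.80, 0.20], [0.25, 0.75]]
stationary, mean_run, symbol = two_state_summary(A_demo, B_demo)
print(stationary)  # approximately [0.6, 0.4]
print(mean_run)    # approximately [8.33, 5.56]
print(symbol)      # approximately [0.58, 0.42]
```

With only \(T=80\) draws, the simulated path can deviate noticeably from these long-run values. Runs of several identical symbols, as in the first twelve observations, are nonetheless expected.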

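Initialize \(A\), \(B\), and \(\pi\) away from the truth. Then alternate the E-step and the M-step for 30 iterations. The E-step is forward-backward, rescaled at every step to avoid underflow, and yields \(\alpha_t\), \(\beta_t\), \(\gamma_t\), and the expected transitions \(\xi_t\). Here \(\pi\) is held fixed, and only \(A\) and \(B\) are re-estimated.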

```python
A = np.array([[0.65, 0.35], [0.35, 0.65]], float)  # deliberately vague starting values
B = np.array([[0.6, 0.4], [0.4, 0.6]], float)
pi0 = np.array([0.5, 0.5])  # initial-state distribution, held fixed in this loop

for _ in range(30):
    # Forward pass: alpha[t] proportional to p(y_{1:t}, S_t), rescaled to sum to 1 at each step
    alpha = np.zeros((T, 2)); scale = np.zeros(T)
    alpha[0] = pi0 * B[:, y[0]]; scale[0] = alpha[0].sum(); alpha[0] /= scale[0]
    for t in range(1, T):
        alpha[t] = (alpha[t-1] @ A) * B[:, y[t]]
        scale[t] = alpha[t].sum(); alpha[t] /= scale[t]
    # Backward pass: beta[t] proportional to p(y_{t+1:T} | S_t); only ratios within each t matter
    beta = np.ones((T, 2))
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (B[:, y[t+1]] * beta[t+1])
        beta[t] /= beta[t].sum()
    # Smoothed state probabilities gamma_t(i)
    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)
    # Expected transition counts: sum over t of xi_t(i, j)
    xi_sum = np.zeros((2, 2))
    for t in range(T - 1):
        xi = alpha[t, :, None] * A * B[:, y[t+1]][None, :] * beta[t+1][None, :]
        xi_sum += xi / xi.sum()
    # M-step: normalize expected counts row by row
    A = xi_sum / xi_sum.sum(axis=1, keepdims=True)
    B[:, 0] = gamma[y == 0].sum(axis=0); B[:, 1] = gamma[y == 1].sum(axis=0)
    B /= B.sum(axis=1, keepdims=True)
A, B
```

```
(array([[0.83005533, 0.16994467],
        [0.26779448, 0.73220552]]),
 array([[0.89947822, 0.10052178],
        [0.16834231, 0.83165769]]))
```
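The loop above uses `scale[t]` only to normalize \(\alpha_t\). The product of these scaling constants is the likelihood, however, so summing their logs gives \(\log p(y_{1:T})\) at every iteration. That value should never decrease, which makes it a convenient convergence check.

The sketch below is a self-contained, pure-Python version of the same scaled recursion that records this trace. The function name `baum_welch` is a choice made here, not part of the original. So are the seed and the longer simulated sequence (\(T=200\)). As in the numpy loop, \(\pi\) stays fixed.

```python
import math
import random


def baum_welch(y, A, B, pi, n_iter=30):
    """Scaled Baum-Welch for a discrete HMM; returns (A, B, loglik_trace).

    As in the numpy loop on this page, pi is held fixed and only A and B are
    re-estimated. loglik_trace[k] is log p(y_{1:T}) before the k-th update.
    """
    K, M, T = len(pi), len(B[0]), len(y)
    trace = []
    for _ in range(n_iter):
        # Forward pass: alpha[t] normalized to sum to 1; scale[t] = p(y_t | y_{1:t-1})
        alpha, scale = [], []
        for t in range(T):
            if t == 0:
                a = [pi[i] * B[i][y[0]] for i in range(K)]
            else:
                a = [sum(alpha[t - 1][i] * A[i][j] for i in range(K)) * B[j][y[t]]
                     for j in range(K)]
            c = sum(a)
            alpha.append([v / c for v in a])
            scale.append(c)
        # The product of the scaling constants is the likelihood
        trace.append(sum(math.log(c) for c in scale))
        # Backward pass, renormalized at each step (only within-t ratios matter)
        beta = [[1.0] * K for _ in range(T)]
        for t in range(T - 2, -1, -1):
            b = [sum(A[i][j] * B[j][y[t + 1]] * beta[t + 1][j] for j in range(K))
                 for i in range(K)]
            beta[t] = [v / sum(b) for v in b]
        # Smoothed state probabilities gamma_t(i)
        gamma = []
        for t in range(T):
            g = [alpha[t][i] * beta[t][i] for i in range(K)]
            gamma.append([v / sum(g) for v in g])
        # Expected transition counts: sum over t of xi_t(i, j)
        xi_sum = [[0.0] * K for _ in range(K)]
        for t in range(T - 1):
            x = [[alpha[t][i] * A[i][j] * B[j][y[t + 1]] * beta[t + 1][j]
                  for j in range(K)] for i in range(K)]
            z = sum(map(sum, x))
            for i in range(K):
                for j in range(K):
                    xi_sum[i][j] += x[i][j] / z
        # M-step: normalize expected counts row by row
        A = [[v / sum(row) for v in row] for row in xi_sum]
        counts = [[sum(gamma[t][i] for t in range(T) if y[t] == k) for k in range(M)]
                  for i in range(K)]
        B = [[v / sum(row) for v in row] for row in counts]
    return A, B, trace


# Simulate a longer sequence from the same model (seed and length are arbitrary choices)
random.seed(1)
A_sim = [[0.88, 0.12], [0.18, 0.82]]
B_sim = [[0.80, 0.20], [0.25, 0.75]]
s_sim = [random.randrange(2)]
for _ in range(199):
    s_sim.append(0 if random.random() < A_sim[s_sim[-1]][0] else 1)
y_sim = [0 if random.random() < B_sim[state][0] else 1 for state in s_sim]

A_hat, B_hat, trace = baum_welch(
    y_sim, A=[[0.65, 0.35], [0.35, 0.65]], B=[[0.6, 0.4], [0.4, 0.6]], pi=[0.5, 0.5]
)
print([round(v, 2) for v in trace[:3]], round(trace[-1], 2))
```

Holding \(\pi\) fixed still gives a valid EM step, because the M-step maximizes the expected complete-data log-likelihood over \(A\) and \(B\) exactly. The trace is therefore non-decreasing up to rounding.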

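Plot the smoothed probability \(\Pr(S_t=1\mid y_{1:T})\) against the true hidden state. EM labels the states arbitrarily, so first check that learned state 1 corresponds to true state 1, for example that row 1 of the estimated \(B\) favors symbol 1 as `B_true` does. If it does not, plot `gamma[:, 0]` instead.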

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(7, 3.2))
ax.step(np.arange(T), s, where="mid", alpha=0.45, label="true hidden state")
ax.plot(gamma[:, 1], color="#ffcc66", lw=2.3, label="Pr(state=1 | data)")
# s=12 here is the marker size; the small vertical offset keeps dots off the step line
ax.scatter(np.arange(T), y + 0.03, s=12, alpha=0.45, label="observed symbol")
ax.set(ylim=(-0.1, 1.1), xlabel="time")
ax.legend()
plt.show()
```

Forward-backward probabilities recover a hidden two-state path.
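The figure compares the smoothed probabilities with the true path by eye. To put a number on it, decode each time point to its most probable state and count agreements. Because EM may return the states in either order, the sketch scores both labelings. The helper `posterior_decoding_accuracy` and the toy inputs are hypothetical, and the sketch assumes two states.

```python
def posterior_decoding_accuracy(gamma, s):
    """Fraction of time points where argmax_i gamma_t(i) equals the true state s_t.

    EM state labels are arbitrary, so for two states both labelings are tried
    and the better one is reported.
    """
    decoded = [max(range(len(g)), key=g.__getitem__) for g in gamma]
    agree = sum(d == st for d, st in zip(decoded, s)) / len(s)
    return max(agree, 1 - agree)  # 1 - agree is the accuracy after swapping labels 0 <-> 1


gamma_demo = [[0.9, 0.1], [0.7, 0.3], [0.2, 0.8], [0.4, 0.6], [0.55, 0.45]]
s_demo = [1, 1, 0, 0, 0]  # labels happen to be swapped relative to gamma_demo
print(posterior_decoding_accuracy(gamma_demo, s_demo))  # 0.8
```

Applied to the arrays on this page, `posterior_decoding_accuracy(gamma.tolist(), s.tolist())` would give the share of time points decoded correctly. The value depends on the simulated draw, so none is quoted here.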

Implementations

hmmlearn (Python), pomegranate (Python), depmixS4 (R)