Core Contribution
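Baum-Welch is the EM algorithm for hidden Markov models. Hidden states \(S_t\) follow a Markov chain with initial distribution \(\pi\) and transition matrix \(A\), where \(A_{ij}=\Pr(S_t=j\mid S_{t-1}=i)\). Observations \(Y_t\) are emitted from state-specific distributions collected in \(B\), where \(B_{ik}=\Pr(Y_t=k\mid S_t=i)\). The likelihood sums over all hidden paths: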
\[
p(y_{1:T})=\sum_{s_{1:T}} \pi_{s_1}B_{s_1y_1}\prod_{t=2}^T A_{s_{t-1}s_t}B_{s_ty_t}.
\]
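The sum over paths has \(K^T\) terms (\(2^T\) here), so evaluating it directly is only feasible for tiny \(T\). The forward recursion computes the same number in \(O(TK^2)\) time. A quick sketch that checks the two agree, assuming a uniform \(\pi\), the transition and emission matrices used in the simulation below, and a made-up six-symbol sequence; both function names are illustrative.

```python
import itertools

import numpy as np


def brute_force_likelihood(pi, A, B, y):
    """Sum pi[s1] B[s1, y1] prod_t A[s_{t-1}, s_t] B[s_t, y_t] over every hidden path."""
    total = 0.0
    for path in itertools.product(range(len(pi)), repeat=len(y)):
        p = pi[path[0]] * B[path[0], y[0]]
        for t in range(1, len(y)):
            p *= A[path[t - 1], path[t]] * B[path[t], y[t]]
        total += p
    return total


def forward_likelihood(pi, A, B, y):
    """Same quantity via the forward recursion alpha_t = (alpha_{t-1} @ A) * B[:, y_t]."""
    alpha = pi * B[:, y[0]]
    for t in range(1, len(y)):
        alpha = (alpha @ A) * B[:, y[t]]
    return alpha.sum()


pi = np.array([0.5, 0.5])                         # assumed uniform initial distribution
A = np.array([[0.88, 0.12], [0.18, 0.82]])
B = np.array([[0.80, 0.20], [0.25, 0.75]])
y = np.array([1, 1, 0, 0, 1, 0])                  # made-up observation sequence

print(brute_force_likelihood(pi, A, B, y), forward_likelihood(pi, A, B, y))
```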
The forward-backward algorithm computes smoothed probabilities
\[
\gamma_t(i)=\Pr(S_t=i\mid y_{1:T})
\]
and the pairwise posteriors \(\xi_t(i,j)=\Pr(S_t=i,S_{t+1}=j\mid y_{1:T})\), whose sums over \(t\) are the expected transition counts. The M-step normalizes these expected counts to update \(A\) and \(B\).
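Written out for the unscaled quantities \(\alpha_t(i)=p(y_{1:t},S_t=i)\) and \(\beta_t(i)=p(y_{t+1:T}\mid S_t=i)\), the standard recursions and updates are as follows; rescaling \(\alpha_t\) and \(\beta_t\) at each step, as the implementation below does, leaves the normalized \(\gamma_t\) and \(\xi_t\) unchanged.
\[
\begin{aligned}
\alpha_1(i) &= \pi_i B_{iy_1}, & \alpha_t(j) &= \sum_i \alpha_{t-1}(i)\,A_{ij}B_{jy_t},\\
\beta_T(i) &= 1, & \beta_t(i) &= \sum_j A_{ij}B_{jy_{t+1}}\,\beta_{t+1}(j),\\
\gamma_t(i) &= \frac{\alpha_t(i)\,\beta_t(i)}{p(y_{1:T})}, & \xi_t(i,j) &= \frac{\alpha_t(i)\,A_{ij}B_{jy_{t+1}}\,\beta_{t+1}(j)}{p(y_{1:T})},
\end{aligned}
\]
with \(p(y_{1:T})=\sum_i\alpha_T(i)\). The M-step then sets
\[
A_{ij}\leftarrow\frac{\sum_{t=1}^{T-1}\xi_t(i,j)}{\sum_{t=1}^{T-1}\gamma_t(i)},\qquad
B_{ik}\leftarrow\frac{\sum_{t:\,y_t=k}\gamma_t(i)}{\sum_{t=1}^{T}\gamma_t(i)}.
\]
Because \(\sum_j\xi_t(i,j)=\gamma_t(i)\), dividing the expected transition counts by their row sums gives the same \(A\) as dividing by \(\sum_{t<T}\gamma_t(i)\). A full EM update would also set \(\pi_i\leftarrow\gamma_1(i)\); the implementation below keeps \(\pi\) fixed.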
Minimal Implementation
Simulate hidden states \(S_t\) and observed symbols \(Y_t\) from a two-state HMM.
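import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)  # seed assumed; the original seed is not shown, so y[:12] may differ
T = 80
A_true = np.array([[0.88, 0.12], [0.18, 0.82]])  # A_true[i, j] = Pr(S_t = j | S_{t-1} = i)
B_true = np.array([[0.80, 0.20], [0.25, 0.75]])  # B_true[i, k] = Pr(Y_t = k | S_t = i)
s = np.zeros(T, dtype=int)  # hidden states
y = np.zeros(T, dtype=int)  # observed symbols
s[0] = rng.integers(2)  # uniform initial state
y[0] = rng.choice(2, p=B_true[s[0]])
for t in range(1, T):
    s[t] = rng.choice(2, p=A_true[s[t - 1]])  # Markov transition
    y[t] = rng.choice(2, p=B_true[s[t]])  # state-specific emission
y[:12]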
array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0])
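If the hidden path were observed, the maximum-likelihood estimates of \(A\) and \(B\) would simply be normalized transition and emission counts; Baum-Welch replaces those counts with the expected counts \(\sum_t\xi_t\) and \(\sum_t\gamma_t\). A sketch of the fully observed case, assuming a fixed seed and a longer series (\(T=5000\)) so the counts settle near the true matrices; `count_mle` is a hypothetical helper.

```python
import numpy as np


def count_mle(s, y, n_states=2, n_symbols=2):
    """Hypothetical helper: fully observed MLE as normalized transition and emission counts."""
    A_counts = np.zeros((n_states, n_states))
    B_counts = np.zeros((n_states, n_symbols))
    np.add.at(A_counts, (s[:-1], s[1:]), 1)  # A_counts[i, j] = number of i -> j transitions
    np.add.at(B_counts, (s, y), 1)           # B_counts[i, k] = times state i emitted symbol k
    A_hat = A_counts / A_counts.sum(axis=1, keepdims=True)
    B_hat = B_counts / B_counts.sum(axis=1, keepdims=True)
    return A_hat, B_hat


rng = np.random.default_rng(0)  # assumed seed
A_true = np.array([[0.88, 0.12], [0.18, 0.82]])
B_true = np.array([[0.80, 0.20], [0.25, 0.75]])
T = 5000  # assumed length, much longer than the T = 80 used above
s = np.zeros(T, dtype=int)
y = np.zeros(T, dtype=int)
s[0] = rng.integers(2)
y[0] = rng.choice(2, p=B_true[s[0]])
for t in range(1, T):
    s[t] = rng.choice(2, p=A_true[s[t - 1]])
    y[t] = rng.choice(2, p=B_true[s[t]])

A_hat, B_hat = count_mle(s, y)
print(np.round(A_hat, 3))
print(np.round(B_hat, 3))
```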
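Initialize \(A\), \(B\), and \(\pi\), then run 30 EM iterations. Each iteration runs forward-backward to get \(\alpha_t\), \(\beta_t\), \(\gamma_t\), and the expected transitions \(\xi_t\), then applies the M-step. \(\alpha_t\) and \(\beta_t\) are rescaled at every step to avoid numerical underflow, and \(\pi\) is kept fixed.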
A = np.array([[0.65, 0.35], [0.35, 0.65]], float)  # initial guess for transitions
B = np.array([[0.6, 0.4], [0.4, 0.6]], float)  # initial guess for emissions
pi0 = np.array([0.5, 0.5])  # initial-state distribution, held fixed
for _ in range(30):
    # E-step, forward pass: alpha[t] = Pr(S_t | y_{1:t}) after rescaling by scale[t]
    alpha = np.zeros((T, 2)); scale = np.zeros(T)
    alpha[0] = pi0 * B[:, y[0]]; scale[0] = alpha[0].sum(); alpha[0] /= scale[0]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * B[:, y[t]]
        scale[t] = alpha[t].sum(); alpha[t] /= scale[t]
    # backward pass: beta[t] is proportional to p(y_{t+1:T} | S_t)
    beta = np.ones((T, 2))
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (B[:, y[t + 1]] * beta[t + 1])
        beta[t] /= beta[t].sum()
    # smoothed state probabilities gamma[t, i] = Pr(S_t = i | y_{1:T})
    gamma = alpha * beta
    gamma /= gamma.sum(axis=1, keepdims=True)
    # expected transition counts: sum over t of xi_t(i, j)
    xi_sum = np.zeros((2, 2))
    for t in range(T - 1):
        xi = alpha[t, :, None] * A * B[:, y[t + 1]][None, :] * beta[t + 1][None, :]
        xi_sum += xi / xi.sum()
    # M-step: normalize expected counts row by row
    A = xi_sum / xi_sum.sum(axis=1, keepdims=True)
    B[:, 0] = gamma[y == 0].sum(axis=0); B[:, 1] = gamma[y == 1].sum(axis=0)
    B /= B.sum(axis=1, keepdims=True)
A, B
(array([[0.83005533, 0.16994467],
[0.26779448, 0.73220552]]),
array([[0.89947822, 0.10052178],
[0.16834231, 0.83165769]]))
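EM should never decrease \(\log p(y_{1:T})\) from one iteration to the next, and the scale factors provide it at no extra cost: each `scale[t]` is the predictive probability of the current symbol given the earlier ones, so their logs sum to \(\log p(y_{1:T})\). The sketch below wraps the loop above in a hypothetical `baum_welch` function, records this quantity at every E-step, and checks that it never drops beyond round-off. The seed is an assumption, so the fitted matrices need not match the output shown above.

```python
import numpy as np


def simulate_hmm(A, B, T, rng):
    """Draw a hidden path s and symbols y from a two-state, two-symbol HMM."""
    s = np.zeros(T, dtype=int)
    y = np.zeros(T, dtype=int)
    s[0] = rng.integers(2)
    y[0] = rng.choice(2, p=B[s[0]])
    for t in range(1, T):
        s[t] = rng.choice(2, p=A[s[t - 1]])
        y[t] = rng.choice(2, p=B[s[t]])
    return s, y


def baum_welch(y, A, B, pi0, n_iter=30):
    """Scaled Baum-Welch with pi held fixed; returns A, B, gamma and log p(y) per iteration."""
    A = np.array(A, dtype=float)  # copies, so the caller's arrays are not modified
    B = np.array(B, dtype=float)
    T, K = len(y), len(pi0)
    loglik = []
    for _ in range(n_iter):
        alpha = np.zeros((T, K))
        scale = np.zeros(T)
        alpha[0] = pi0 * B[:, y[0]]
        scale[0] = alpha[0].sum()
        alpha[0] /= scale[0]
        for t in range(1, T):
            alpha[t] = (alpha[t - 1] @ A) * B[:, y[t]]
            scale[t] = alpha[t].sum()
            alpha[t] /= scale[t]
        loglik.append(np.log(scale).sum())  # log p(y_{1:T}) under the current A, B
        beta = np.ones((T, K))
        for t in range(T - 2, -1, -1):
            beta[t] = A @ (B[:, y[t + 1]] * beta[t + 1])
            beta[t] /= beta[t].sum()
        gamma = alpha * beta
        gamma /= gamma.sum(axis=1, keepdims=True)
        xi_sum = np.zeros((K, K))
        for t in range(T - 1):
            xi = alpha[t, :, None] * A * (B[:, y[t + 1]] * beta[t + 1])[None, :]
            xi_sum += xi / xi.sum()
        A = xi_sum / xi_sum.sum(axis=1, keepdims=True)  # M-step for transitions
        for k in range(B.shape[1]):                     # M-step for emissions
            B[:, k] = gamma[y == k].sum(axis=0)
        B /= B.sum(axis=1, keepdims=True)
    return A, B, gamma, np.array(loglik)


rng = np.random.default_rng(0)  # assumed seed
A_true = np.array([[0.88, 0.12], [0.18, 0.82]])
B_true = np.array([[0.80, 0.20], [0.25, 0.75]])
s, y = simulate_hmm(A_true, B_true, 80, rng)
A_hat, B_hat, gamma, loglik = baum_welch(
    y,
    A=[[0.65, 0.35], [0.35, 0.65]],
    B=[[0.6, 0.4], [0.4, 0.6]],
    pi0=np.array([0.5, 0.5]),
)
print(np.round(loglik[[0, 1, -1]], 3))
print("monotone:", bool(np.all(np.diff(loglik) >= -1e-8)))
```

A plateau in `loglik` is a natural stopping rule in place of the fixed 30 iterations, though a plateau only indicates a local maximum.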
Plot the smoothed probability \(\Pr(S_t=1\mid y_{1:T})\) from the final E-step against the true hidden state, with the observed symbols overlaid.
fig, ax = plt.subplots(figsize=(7, 3.2))
ax.step(np.arange(T), s, where="mid", alpha=0.45, label="true hidden state")
ax.plot(gamma[:, 1], color="#ffcc66", lw=2.3, label="Pr(state=1 | data)")
# small vertical offset keeps the symbols from sitting exactly on the step line
ax.scatter(np.arange(T), y + 0.03, s=12, alpha=0.45, label="observed symbol")
ax.set(ylim=(-0.1, 1.1), xlabel="time")
ax.legend()
plt.show()
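The plot can also be summarized by a single number: threshold \(\Pr(S_t=1\mid y_{1:T})\) at 0.5 and count how often that matches the true state, compared with reading the state straight off the symbol. This sketch uses the true \(A\) and \(B\) rather than the EM estimates, so it measures what smoothing alone buys; the seed, the longer series, and the helper `smoothed_posteriors` are assumptions.

```python
import numpy as np


def smoothed_posteriors(y, A, B, pi0):
    """Hypothetical helper: scaled forward-backward, gamma[t, i] = Pr(S_t = i | y_{1:T})."""
    T, K = len(y), len(pi0)
    alpha = np.zeros((T, K))
    alpha[0] = pi0 * B[:, y[0]]
    alpha[0] /= alpha[0].sum()
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * B[:, y[t]]
        alpha[t] /= alpha[t].sum()  # rescale to avoid underflow
    beta = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        beta[t] = A @ (B[:, y[t + 1]] * beta[t + 1])
        beta[t] /= beta[t].sum()
    gamma = alpha * beta
    return gamma / gamma.sum(axis=1, keepdims=True)


rng = np.random.default_rng(1)  # assumed seed
A_true = np.array([[0.88, 0.12], [0.18, 0.82]])
B_true = np.array([[0.80, 0.20], [0.25, 0.75]])
T = 5000  # assumed length, long enough for a stable accuracy estimate
s = np.zeros(T, dtype=int)
y = np.zeros(T, dtype=int)
s[0] = rng.integers(2)
y[0] = rng.choice(2, p=B_true[s[0]])
for t in range(1, T):
    s[t] = rng.choice(2, p=A_true[s[t - 1]])
    y[t] = rng.choice(2, p=B_true[s[t]])

gamma = smoothed_posteriors(y, A_true, B_true, np.array([0.5, 0.5]))
map_acc = np.mean((gamma[:, 1] > 0.5) == s)  # threshold Pr(S_t = 1 | y) at 0.5
naive_acc = np.mean(y == s)                  # read the state directly off the symbol
print(f"posterior decoding: {map_acc:.3f}, symbol only: {naive_acc:.3f}")
```

Because the chain is sticky, a single symbol that disagrees with its neighbors is usually overruled by the smoother, which is where the gain over the symbol-only rule comes from.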