Hello there! I’m new to Pyro and trying to code up an ar-HMM with Gaussian emissions that can be trained on multiple sequences of different lengths using DiscreteHMM. This is the model:
init_probs ~ Dirichlet
transition_probs ~ Dirichlet
weights ~ Normal(0, 0.1)
biases ~ Normal(kmeans_centers, 0.1)
z_0 ~ Cat(init_probs)
z_t ~ Cat(transition_probs[z_{t-1}])
x_t ~ Normal(weights[z_t] * x_{t-1} + biases[z_t], 0.1)
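(Just to make the generative process concrete, here’s a rough standalone simulation of one sequence; k, n and T are placeholder sizes and the uniform initial/transition probabilities are only for illustration, not part of my actual code.)

import torch

k, n, T = 5, 2, 100  # placeholder sizes
init_probs = torch.ones(k) / k
transition_probs = torch.ones(k, k) / k
weights = 0.1 * torch.randn(k, n)
biases = torch.randn(k, n)

x = torch.zeros(T, n)
z = torch.distributions.Categorical(init_probs).sample()
x[0] = biases[z] + 0.1 * torch.randn(n)
for t in range(1, T):
    # Markov transition, then autoregressive Gaussian emission
    z = torch.distributions.Categorical(transition_probs[z]).sample()
    x[t] = weights[z] * x[t - 1] + biases[z] + 0.1 * torch.randn(n)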
The main issue is that training doesn’t seem to improve on the initial parameter values, and training to convergence just leads to the model predicting one state for everything: the learned transition matrix ends up with very high self-transition probabilities for only one state. I’m lost as to what the issue could be here.
Does anyone have an idea what the problem in training could be, or how I might better debug it?
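(For reference, the way I’ve been looking at the learned values is just dumping the param store after training; since the guide is an AutoDelta, I’m assuming its point estimates show up under an "AutoDelta." prefix.)

import pyro

for name, value in pyro.get_param_store().items():
    print(name, tuple(value.shape))

# e.g. the learned transition matrix (name assumed from the AutoDelta prefix)
print(pyro.param("AutoDelta.transition_probs").detach())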
Here’s an outline of my model code (full code):
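(The outline assumes roughly these imports:)

import math
import numpy as np
import torch
from sklearn.cluster import KMeans
from tqdm import tqdm
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam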
# Compute k-means centers to initialize the emission biases
def init_kmeans():
    # Aggregate the observations of all sequences into one (num_points, n) array
    concat_X = X.reshape(-1, n)
    km = KMeans(k).fit(concat_X)
    centers = torch.tensor(km.cluster_centers_, dtype=torch.float32)
    return centers

kmeans_centers = init_kmeans()
def model(X, lengths, kmeans_centers, k=5, batch_size=None):
    """Defines an ar-HMM with k latent states."""
    # Get dimensions
    num_seqs, t, n = X.shape  # n is the number of emission dims
    # Sample global parameters according to their priors
    initial_p = pyro.sample("initial_probs", dist.Dirichlet(torch.ones(k) * 10))
    transition_p = pyro.sample(
        "transition_probs", dist.Dirichlet(torch.ones(k, k)).to_event(1)
    )
    with pyro.plate("coefficients", k):
        weights = pyro.sample("weights", dist.Normal(torch.zeros(k, n), 0.1).to_event(1))
        biases = pyro.sample("biases", dist.Normal(kmeans_centers, 0.1).to_event(1))
    # Sample the latent states and observations
    sequences_plate = pyro.plate("sequences", num_seqs, batch_size)
    with sequences_plate as batch:
        lengths = lengths[batch]
        X_batch = X[batch, : lengths.max()]
        timepoints = torch.arange(X_batch.size(1))
        # Per-state emission means, broadcast to (batch, time, k, n)
        loc = weights * X_batch.unsqueeze(2) + biases
        # len(batch) also covers the full-data case (batch_size=None)
        assert loc.shape == torch.Size([len(batch), lengths.max(), k, n])
        # Fixed variance of the emission distribution
        scale = torch.ones(k, 1) * 0.1
        # Move the emission dims into the event dimension
        obs_dist = dist.Normal(loc, scale).to_event(1)
        # Mask shape must be broadcastable with the obs_dist batch shape
        obs_dist = obs_dist.mask((timepoints < lengths.unsqueeze(-1)).unsqueeze(-1))
        hmm_dist = dist.DiscreteHMM(torch.log(initial_p), torch.log(transition_p), obs_dist)
        pyro.sample("x", hmm_dist, obs=X_batch)
def train(n_steps=1000, lr=0.01, batch_size=None, seed=None, early_stopping=100):
    if seed is not None:
        pyro.set_rng_seed(seed)
    guide = AutoDelta(model)
    optim = Adam({"lr": lr})
    elbo = Trace_ELBO()
    svi = SVI(model, guide, optim, elbo)
    bar = tqdm(range(n_steps), total=n_steps)
    losses = np.zeros(n_steps)
    min_loss = math.inf
    early = 0
    for step in bar:
        loss = svi.step(X, lengths, kmeans_centers, batch_size=batch_size)
        losses[step] = loss
        bar.set_description(f"loss: {loss:.2e}")
        # Early stopping: count steps since the loss last improved
        if early_stopping is None:
            continue
        if loss < min_loss:
            min_loss = loss
            early = 0
        else:
            early += 1
            if early > early_stopping:
                losses = losses[:step]
                break
    return losses
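For completeness, a training run is invoked like this (the hyperparameter values are just examples):

losses = train(n_steps=2000, lr=0.01, batch_size=16, seed=0)
print("final loss:", losses[-1])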