Training Issues with DiscreteHMM for autoregressive model

Hello there! I’m new to Pyro and trying to code up an ar-HMM with Gaussian emissions that can be trained on multiple sequences of different lengths using DiscreteHMM. This is the model:

init_probs ~ Dirichlet
transition_probs ~ Dirichlet
weights ~ Normal(0, 0.1)
biases ~ Normal(kmeans_centers, 0.1)

z_0 ~ Cat(init_probs)
z_t ~ Cat(transition_probs[z_{t-1}])
x_t ~ Normal(weights[z_t] * x_{t-1} + biases[z_t], 0.1)
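
For concreteness, here is the generative process above as a plain PyTorch loop for a single sequence (just a sketch: the 0.1 emission scale is taken from the code below, and since the spec doesn't say how x_0 is generated, starting it at the state bias is an assumption):

import torch

def simulate(init_probs, transition_probs, weights, biases, T, scale=0.1):
    """Sample one length-T sequence from the ar-HMM spec above (sketch)."""
    z = torch.multinomial(init_probs, 1).item()
    x = torch.zeros(T, biases.size(-1))
    # x_0 has no predecessor, so start it at the state bias (an assumption)
    x[0] = torch.normal(biases[z], scale)
    for t in range(1, T):
        z = torch.multinomial(transition_probs[z], 1).item()
        x[t] = torch.normal(weights[z] * x[t - 1] + biases[z], scale)
    return x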

The main issue is that training doesn’t seem to improve on the initial parameter values, and training to convergence just leads to the model predicting a single state for everything: the learned transition matrix ends up with a very high self-transition probability for only one state. I’m lost as to what the issue could be here.
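
(The learned transition matrix can be read off the AutoDelta guide, e.g. with something like the following sketch, assuming access to the guide created in train() below:)

map_estimates = guide.median()                # AutoDelta point estimates, keyed by site name
print(map_estimates["transition_probs"])      # rows sum to 1; inspect the diagonal
print(map_estimates["weights"], map_estimates["biases"])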

Does anyone have an idea what the issue in training could be, or how I might best debug it?

Here’s an outline of my model code (full code):


import math

import numpy as np
import torch
from sklearn.cluster import KMeans
from tqdm import tqdm

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam


# Compute kmeans centers to initialize biases
def init_kmeans():
    # aggregate all observations (X, n and k are defined elsewhere)
    concat_X = X.reshape(-1, n)
    km = KMeans(k).fit(concat_X)
    centers = torch.tensor(km.cluster_centers_, dtype=torch.float32)
    return centers

kmeans_centers = init_kmeans()

def model(X, lengths, kmeans_centers, k=5, batch_size=None):
    """Defines an ar-HMM with k latent states."""
    # Get parameters

    num_seqs, t, n = X.shape # n is the number of emission dims

    # Sample parameters according to priors

    initial_p = pyro.sample("initial_probs", dist.Dirichlet(torch.ones(k) * 10))

    transition_p = pyro.sample("transition_probs", dist.Dirichlet(torch.ones((k, k))).to_event(1))

    with pyro.plate("coefficients", k):
        weights = pyro.sample("weights", dist.Normal(torch.zeros((k, n)), 0.1).to_event(1))
        biases = pyro.sample("biases", dist.Normal(kmeans_centers, 0.1).to_event(1))

    # Sample the latent states and observations

    sequences_plate = pyro.plate("sequences", num_seqs, batch_size)

    with sequences_plate as batch:
        lengths = lengths[batch]
        X_batch = X[batch, : lengths.max()]
        timepoints = torch.arange(X_batch.size(1))

        loc = weights * X_batch.unsqueeze(2) + biases

        assert loc.shape == torch.Size([batch_size, lengths.max(), k, n])

        # Scale (std dev) of emission distribution
        scale = torch.ones(k, 1) * 0.1

        # Move emissions into event dimension
        obs_dist = dist.Normal(loc, scale).to_event(1)

        # Mask shape must be broadcastable with obs_dist batch shape
        obs_dist = obs_dist.mask((timepoints < lengths.unsqueeze(-1)).unsqueeze(-1))

        hmm_dist = dist.DiscreteHMM(torch.log(initial_p), torch.log(transition_p), obs_dist)

        pyro.sample("x", hmm_dist, obs=X_batch)


def train(self, n_steps=1000, lr=0.01, seed=None, early_stopping=100):
    if seed is not None:
        pyro.set_rng_seed(seed)

    guide = AutoDelta(self.model)
    optim = Adam({"lr": lr})
    elbo = Trace_ELBO()
    svi = SVI(self.model, guide, optim, elbo)

    bar = tqdm(range(n_steps), total=n_steps)
    losses = np.zeros(n_steps)
    min_loss = math.inf
    early = 0  # consecutive steps without improvement
    for step in bar:
        loss = svi.step()
        losses[step] = loss
        bar.set_description(f"loss: {loss:.2e}")
        # Early stopping
        if early_stopping is None:
            continue
        if loss < min_loss:
            min_loss = loss
            early = 0
        else:
            early += 1
        if early > early_stopping:
            losses = losses[:step]
            break

    return losses
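
For reference, here is a minimal shape check for how I’m feeding DiscreteHMM (dummy sizes; the emission distribution’s rightmost batch dims should be (num_steps, num_states), and log_prob should then return one value per sequence):

import torch
import pyro.distributions as dist

# Dummy sizes matching the asserts above: 4 sequences, T=10 steps, k=5 states, n=3 dims
batch_size, T, k, n = 4, 10, 5, 3
loc = torch.zeros(batch_size, T, k, n)
scale = torch.ones(k, 1) * 0.1
obs_dist = dist.Normal(loc, scale).to_event(1)
assert obs_dist.batch_shape == (batch_size, T, k)    # (..., num_steps, num_states)

hmm = dist.DiscreteHMM(torch.zeros(k), torch.zeros(k, k), obs_dist)
x = torch.randn(batch_size, T, n)
assert hmm.log_prob(x).shape == (batch_size,)        # one log-prob per sequence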

As constructed, your model will run init_kmeans() again and again at each iteration. Do the initialization outside of the model and pass in whatever you need.
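
For example, something along these lines (a rough sketch against the condensed signatures in your post; batch_size=8 is just illustrative):

import functools

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# Run k-means exactly once, then bind the centers (and data) to the model so
# nothing expensive is recomputed inside each SVI step.
kmeans_centers = init_kmeans()
bound_model = functools.partial(model, X, lengths, kmeans_centers, k=5, batch_size=8)

guide = AutoDelta(bound_model)
svi = SVI(bound_model, guide, Adam({"lr": 0.01}), Trace_ELBO())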

Thanks for your quick reply, Martin! Actually, in the full code this isn’t the case; that’s my mistake from condensing the code for the forum post. Apologies for the confusion. This is how the model is structured as part of a class, and k-means is computed only once in the __init__ method.


@property
def biases(self):
    return pyro.sample("biases", dist.Normal(self.kmeans_centers, 0.1).to_event(1))

...

def model(self):
        # Sample parameters according to priors

        initial_p = self.initial_probs
        transition_p = self.transition_probs

        with pyro.plate("coefficients", self.k):
            weights = self.weights
            biases = self.biases

        # Sample the latent states and observations

        sequences_plate = pyro.plate("sequences", self.num_seqs, self.batch_size)

        with sequences_plate as batch:
            lengths = self.lengths[batch]
            X_batch = self.X[batch, : lengths.max()]
            timepoints = torch.arange(X_batch.size(1))

            loc = weights * X_batch.unsqueeze(2) + biases

            assert loc.shape == torch.Size(
                [self.batch_size, lengths.max(), self.k, self.n]
            )

            # Scale (std dev) of emission distribution
            scale = torch.ones(self.k, 1) * 0.1

            # Move emissions into event dimension
            obs_dist = dist.Normal(loc, scale).to_event(1)

            # Mask shape must be broadcastable with obs_dist batch shape
            obs_dist = obs_dist.mask((timepoints < lengths.unsqueeze(-1)).unsqueeze(-1))

            hmm_dist = dist.DiscreteHMM(
                torch.log(initial_p), torch.log(transition_p), obs_dist
            )

            pyro.sample("x", hmm_dist, obs=X_batch)

I have no idea what’s happening in your model, but it’s possible that fixing a small emission scale may be problematic.
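
For instance, one option (a sketch, not tested against your setup; k and loc refer to quantities already in your model) is to make the per-state emission scale a learnable positive parameter instead of hard-coding 0.1:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

# Learnable per-state emission scale, constrained positive; this would replace
# the hard-coded `scale = torch.ones(self.k, 1) * 0.1` inside the model.
scale = pyro.param(
    "emission_scale",
    torch.full((k, 1), 0.1),
    constraint=constraints.positive,
)
obs_dist = dist.Normal(loc, scale).to_event(1)

Alternatively, you could put a prior on the scale and pyro.sample it (e.g. a LogNormal), which your AutoDelta guide would then fit as a point estimate.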