Simplex parameter in model becomes NaN

I’m building a very simple HMM with Gaussian emissions.

import pyro
import pyro.distributions as dist

from pyro import poutine
from pyro.optim import Adam

import torch
import torch.nn as nn
from torch.distributions import constraints

def model(observations, num_state):
    """Hidden Markov model with Gaussian emissions and enumerated hidden states.

    Global latent sites (sampled under ``poutine.mask(mask=True)``):
      * ``p_transition`` - one Dirichlet row per state, concentration ``1 / num_state``.
      * ``p_emission``   - per-state emission means, ``Normal(0, 2)``, shape
        ``(num_state, 1)``.
      * ``p_init``       - Dirichlet over the initial state, concentration
        ``1 / num_state``.

    The hidden states ``x_t`` are Categorical sites marked for parallel
    enumeration, and each observation ``y_t`` is ``Normal(mean_of_state, 2)``.

    Args:
        observations: indexable sequence of observations; ``observations[t]`` is
            used as the observed value of ``y_t``.
            NOTE(review): presumably a ``(T, 1)`` float tensor - confirm with caller.
        num_state: number of hidden states.
    """
    assert not torch._C._get_tracing_state()

    with poutine.mask(mask=True):
        p_transition = pyro.sample(
            "p_transition",
            dist.Dirichlet(
                (1 / num_state) * torch.ones(num_state, num_state)
            ).to_event(1),
        )

        p_emission = pyro.sample(
            "p_emission",
            dist.Normal(0, 2).expand([num_state, 1]).to_event(2),
        )

        p_init = pyro.sample(
            "p_init",
            dist.Dirichlet((1 / num_state) * torch.ones(num_state)),
        )

    # initial state
    current_state = pyro.sample(
        "x_{}".format(0),
        dist.Categorical(p_init),
        infer={"enumerate": "parallel"},
    )

    pyro.sample(
        "y_{}".format(0),
        dist.Normal(p_emission[current_state.squeeze(-1)], 2),
        obs=observations[0],
    )

    # later states
    for t in pyro.markov(range(1, len(observations))):
        current_state = pyro.sample(
            "x_{}".format(t),
            dist.Categorical(p_transition[current_state]),
            infer={"enumerate": "parallel"},
        )

        # Fix: drop the state's trailing size-1 dim before indexing, exactly as
        # is done for y_0. p_emission is (num_state, 1), so indexing it with the
        # unsqueezed state would append an extra dimension to the emission means.
        pyro.sample(
            "y_{}".format(t),
            dist.Normal(p_emission[current_state.squeeze(-1)], 2),
            obs=observations[t],
        )

My data is generated by

import numpy as np
import torch
import matplotlib.pyplot as plt
import pyro

from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from scipy.stats import multinomial
from typing import List

def equilibrium_distribution(p_transition):
    """Return the stationary distribution of a square transition matrix.

    The balance equations ``(P^T - I) pi = 0`` are stacked on top of the
    normalisation row ``sum(pi) = 1``; the resulting overdetermined system is
    solved through its normal equations.
    """
    size = p_transition.shape[0]

    # (size + 1) x size system: balance rows followed by a row of ones.
    balance_rows = p_transition.T - np.eye(size)
    system = np.vstack([balance_rows, np.ones((1, size))])

    # Right-hand side: zeros for the balance rows, 1 for the normalisation row.
    rhs = np.zeros(size + 1)
    rhs[-1] = 1

    return np.linalg.solve(system.T @ system, system.T @ rhs)

def markov_sequence(p_init, p_transition, sequence_length):
    """Sample a path of hidden-state indices from a Markov chain.

    When ``p_init`` is ``None`` the first state is drawn from the chain's
    equilibrium distribution; every later state is drawn from the row of
    ``p_transition`` belonging to the previous state. The returned list always
    contains the initial state, followed by ``sequence_length - 1`` further
    states.
    """

    def draw_state(probabilities):
        # A single multinomial trial is a one-hot vector; return its hot index.
        one_hot = multinomial.rvs(1, probabilities)
        return list(one_hot).index(1)

    start_probs = equilibrium_distribution(p_transition) if p_init is None else p_init

    path = [draw_state(start_probs)]
    for _ in range(1, sequence_length):
        path.append(draw_state(p_transition[path[-1]]))
    return path

def get_obs(states, mus, sigmas):
    """Draw one Gaussian observation for every state in ``states``.

    For a state ``s`` the observation is sampled from a normal distribution
    with mean ``mus[s]`` and standard deviation ``sigmas[s]``. Returns a list
    with one sampled value per state, in order.
    """
    return [
        np.random.normal(mus[state], sigmas[state], 1)[0]
        for state in states
    ]

# Ground-truth transition matrix of the 3-state chain used to simulate data.
p = np.array([
    [0.2, 0.6, 0.2],
    [0.4, 0.2, 0.4],
    [0.1, 0.5, 0.4],
])

# 100 hidden states; passing None starts the chain from its equilibrium distribution.
state_seq = markov_sequence(None, p, 100)

# Gaussian emissions with per-state means 1, 10, -10 and standard deviation 2,
# arranged as a single column and converted to a float tensor.
obs = np.array(get_obs(state_seq, [1, 10, -10], [2, 2, 2])).reshape(-1, 1)
obs = torch.from_numpy(obs).to(torch.float)

The model is fit with:

# Delta guide restricted to the sample sites whose name starts with "p_";
# every other site of the model is blocked from the guide.
guide = AutoDelta(
    poutine.block(
        model,
        expose_fn=lambda msg: msg["name"].startswith("p_"),
    )
)

Elbo = TraceTMC_ELBO
elbo = Elbo(max_plate_nesting=1)

optim = Adam({"lr": 0.05})
svi = SVI(model, guide, optim, elbo)

# Run 1000 SVI steps, recording the loss of every step.
losses = []
for i in range(1000):
    loss = svi.step(observations=obs, num_state=3)
    losses.append(loss)

    # Report the current estimate of p_init every 100 iterations.
    if not i % 100:
        print(f"iter {i}")
        print(pyro.get_param_store()["AutoDelta.p_init"])

plt.plot(losses)
plt.show()

The parameter that becomes NaN is p_init.
Its value at every 100th iteration is:

iter 0
tensor([0.3220, 0.3220, 0.3559], grad_fn=<DivBackward0>)
iter 100
tensor([1.6889e-05, 1.3304e-05, 9.9997e-01], grad_fn=<DivBackward0>)
iter 200
tensor([5.3326e-10, 3.7257e-10, 1.0000e+00], grad_fn=<DivBackward0>)
iter 300
tensor([1.9880e-14, 1.3042e-14, 1.0000e+00], grad_fn=<DivBackward0>)
iter 400
tensor([7.9135e-19, 4.9804e-19, 1.0000e+00], grad_fn=<DivBackward0>)
iter 500
tensor([3.2631e-23, 1.9926e-23, 1.0000e+00], grad_fn=<DivBackward0>)
iter 600
tensor([1.3754e-27, 8.2068e-28, 1.0000e+00], grad_fn=<DivBackward0>)
iter 700
tensor([5.8850e-32, 3.4471e-32, 1.0000e+00], grad_fn=<DivBackward0>)
iter 800
tensor([2.5453e-36, 1.4686e-36, 1.0000e+00], grad_fn=<DivBackward0>)

One of the probabilities gets very close to 1 and eventually becomes NaN, but the initial-state probabilities shouldn’t behave like this.

The error message is

ValueError: Expected parameter probs (Tensor of shape (3,)) of distribution Categorical(probs: torch.Size([3])) to satisfy the constraint Simplex(), but found invalid values:
tensor([nan, nan, nan], grad_fn=<DivBackward0>)
    Trace Shapes:      
     Param Sites:      
    Sample Sites:      
p_transition dist | 3 3
            value | 3 3
  p_emission dist | 3 1
            value | 3 1
      p_init dist | 3  
            value | 3  

Hi, thanks for the easily reproducible example! One small bug: you forgot to add .squeeze(-1) when you sample y_t with t>0.
I think part of the problem is that when you use AutoDelta, you try to find the MAP estimate for p_init, while you have only one data point y_0 to inform this estimate. Therefore the most likely p_init is simply the one-hot vector with a 1 at the index for which p_emission is closest to y_0. A degenerate point on the simplex like (1, 0, 0) probably leads to numerical issues.

One way to fix this is to not use AutoDelta for p_init. For instance, you could do

# AutoGuideList and AutoMultivariateNormal are not among the imports shown in
# the snippets above, so bring them into scope here.
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoMultivariateNormal

# Point (MAP) estimates for the transition matrix and the emission means.
guide1 = AutoDelta(poutine.block(model, expose=["p_transition", "p_emission"]))
# Multivariate-normal variational estimate for the initial-state distribution.
guide2 = AutoMultivariateNormal(poutine.block(model, expose=["p_init"]))
# Combine both partial guides into a single guide for the whole model.
guide = AutoGuideList(model)
guide.append(guide1)
guide.append(guide2)

to get the MAP for p_emission and p_transition, but a variational estimate for p_init. I tried this, and there seems to be another problem, because I do not get good estimates for p_transition. The estimated matrix contains very small values.