Why does NUTS give much better results than SVI on a simple case?

I have implemented a simple example from "Bayesian Methods for Hackers" about the number of messages received per day.
The data was generated with the following parameters: size=100, Poisson1(20), Poisson2(50), and a switch on day 70.
Estimates:

Poisson1, Poisson2, Switch
SVI: 20, 37, 51
NUTS: 20, 47, 70

Why is NUTS so much better? Should I improve the guide function (for tau)?

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def daily_messages_guide(data):
    alpha_1 = pyro.param("alpha_1", torch.tensor(1.0),
                         constraint=constraints.positive)
    pyro.sample("lambda_1", dist.Exponential(alpha_1))

    alpha_2 = pyro.param("alpha_2", torch.tensor(1.0),
                         constraint=constraints.positive)
    pyro.sample("lambda_2", dist.Exponential(alpha_2))

    # Beta concentrations must stay positive
    alpha0 = pyro.param("alpha0", torch.tensor(10.0),
                        constraint=constraints.positive)
    beta0 = pyro.param("beta0", torch.tensor(5.0),
                       constraint=constraints.positive)
    pyro.sample("tau", dist.Beta(alpha0, beta0))


def model(data):
    alpha = 1.0 / data.mean()
    lambda_1 = pyro.sample("lambda_1", dist.Exponential(alpha))
    lambda_2 = pyro.sample("lambda_2", dist.Exponential(alpha))
    
    tau = pyro.sample("tau", dist.Uniform(0, 1))
    lambda1_size = (tau * data.size(0) + 1).long()
    lambda2_size = data.size(0) - lambda1_size
    lambda_ = torch.cat([lambda_1.expand((lambda1_size,)),
                         lambda_2.expand((lambda2_size,))])

    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Poisson(lambda_), obs=data)

This discussion is relevant.

You can't use .long() in your model in conjunction with a gradient-based inference algorithm like SVI or HMC; this op blocks gradient flow.


Can you please explain where I use .long() and why it is a problem? Maybe you could add a link to the documentation.
As you can see, HMC works well; the only issue is the SVI estimate of the switch point.

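You are trying to turn a continuous latent variable (tau) into an integer in the line lambda1_size = (tau * data.size(0) + 1).long(). This is a non-differentiable operation: the integer result is detached from the autograd graph, so no gradient from the likelihood reaches tau.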
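A quick way to see this in plain PyTorch (a minimal sketch, not code from the thread; n and the value of tau are illustrative): the continuous expression has a gradient with respect to tau, but after .long() the result is an integer tensor with no grad_fn, so nothing computed from it can send a gradient back.

```python
import torch

n = 100
tau = torch.tensor(0.7, requires_grad=True)

# Continuous expression: autograd can differentiate it with respect to tau.
switch_continuous = tau * n + 1
switch_continuous.backward()
print(tau.grad)  # d(switch)/d(tau) = n = 100

# Casting to an integer dtype yields a tensor that is not part of the
# autograd graph, so a likelihood built from it gives tau no gradient.
switch_index = (tau * n + 1).long()
print(switch_index.requires_grad, switch_index.grad_fn)
```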

It may be that HMC is giving you results that are plausible in this particular case. However, in order for HMC to generate samples that are actually (approximate) samples from the posterior, the log density of the model must be sufficiently smooth. If it is not, arbitrarily strange things may happen.


Understood, thank you for the clarification. What should I do in such cases? Pyro's superpower is universality:… random control flow. I suppose every if/else statement creates a non-differentiable operation?

It depends on the particular case; it's hard to generalize because the space of all possible models is very large. In some cases the discrete structure can be enumerated. In other cases reweighted wake-sleep inference is a possibility. The list goes on…

I see. What about this case, the number of received messages? What would you suggest for the guide function?

I don't know the exact details of your modeling goal, but I'd probably encode the switch as a discrete latent variable and enumerate it out using TraceEnum_ELBO.
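To make "enumerate it out" concrete, here is what summing out a discrete switch point amounts to, written in plain PyTorch rather than Pyro (a sketch; switchpoint_log_joint is a hypothetical helper, and the rates are held fixed at illustrative values, whereas in the real model they are latent). For every candidate tau the per-day rates are built by broadcasting, the Poisson log-likelihoods are summed over days, logsumexp over tau gives log p(data | lambda_1, lambda_2), and softmax over the same vector gives the posterior over tau. No continuous variable is cast to an integer, so the marginal is differentiable in lambda_1 and lambda_2.

```python
import torch
from torch.distributions import Poisson


def switchpoint_log_joint(data, lambda_1, lambda_2):
    """Hypothetical helper: log p(data, tau | lambda_1, lambda_2) for every
    tau in 0..n-1, with a uniform prior over tau (tau = days before the switch)."""
    n = data.size(0)
    taus = torch.arange(n).unsqueeze(-1)  # shape (n, 1): one row per candidate tau
    days = torch.arange(n)                # shape (n,)
    # rate[t, i] = lambda_1 if day i is before switch t, else lambda_2
    rate = torch.where(days < taus, lambda_1, lambda_2)  # shape (n, n)
    log_lik = Poisson(rate).log_prob(data).sum(-1)       # shape (n,)
    log_prior = -torch.log(torch.tensor(float(n)))
    return log_lik + log_prior


# Synthetic data like the question: 70 days at rate 20, then 30 days at rate 50.
torch.manual_seed(0)
data = torch.cat([Poisson(20.0).sample((70,)), Poisson(50.0).sample((30,))])

log_joint = switchpoint_log_joint(data, torch.tensor(20.0), torch.tensor(50.0))
marginal = torch.logsumexp(log_joint, dim=0)     # tau summed out exactly
posterior_tau = torch.softmax(log_joint, dim=0)  # p(tau | data, rates)
print(marginal.item(), posterior_tau.argmax().item())
```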

I tried to enumerate it using TraceEnum_ELBO, but it does not work: the learned when_to_switch values all end up the same.
I want to use parallel enumeration. What is wrong with this code?

def daily_messages_model_plate_enum(data):
    alpha = 1.0 / data.mean()
    lambda_1 = pyro.sample("lambda_1", dist.Exponential(alpha))
    lambda_2 = pyro.sample("lambda_2", dist.Exponential(alpha))

    # enumerate
    taus = pyro.sample("tau", dist.Categorical(logits=torch.ones(len(data))))

    for tau in taus:
        lambda1_size = tau
        lambda2_size = data.size(0) - lambda1_size
        lambda_ = torch.cat([lambda_1.expand((lambda1_size,)),
                             lambda_2.expand((lambda2_size,))])

        with pyro.plate("data_{}".format(tau), len(data)):
            pyro.sample("obs_{}".format(tau), dist.Poisson(lambda_), obs=data)


def daily_messages_guide_enum(data):
    # variational distributions
    alpha_1 = pyro.param("alpha_1", torch.tensor(1.0),
                         constraint=constraints.positive)
    pyro.sample("lambda_1", dist.Exponential(alpha_1))
    
    alpha_2 = pyro.param("alpha_2", torch.tensor(1.0),
                         constraint=constraints.positive)
    pyro.sample("lambda_2", dist.Exponential(alpha_2))

    when_to_switch = pyro.param("when_to_switch", torch.ones(len(data)))
    pyro.sample("tau", dist.Categorical(logits=when_to_switch), infer={'enumerate': 'parallel'})

svi = SVI(daily_messages_model_plate_enum, daily_messages_guide_enum, optimizer, 
          loss=pyro.infer.TraceEnum_ELBO(max_plate_nesting=1))