I’m a bit confused about the behaviour of the following program (it isn't meant to model anything useful):
```python
import pyro
import torch


def pyro_model(true_data):
    p = pyro.sample("p", pyro.distributions.Uniform(0, 1))
    b = pyro.sample(
        "b",
        pyro.distributions.RelaxedBernoulliStraightThrough(
            probs=p, temperature=torch.tensor(0.1)
        ),
    )
    print(b)
    pyro.sample("c", pyro.distributions.Normal(b, 1), obs=true_data)


true_data = torch.tensor([2.0])

mcmc_kernel = pyro.infer.mcmc.NUTS(pyro_model)
mcmc = pyro.infer.MCMC(
    mcmc_kernel,
    num_samples=5,
    warmup_steps=0,
)
mcmc.run(true_data)
```
Notice that there is a print statement right after sampling from the RelaxedBernoulliStraightThrough distribution. I would expect the distribution to return 0 or 1, but it returns floating-point numbers between 0 and 1. This only happens when the model is run inside MCMC, not when it is run directly.
Am I misunderstanding something here?
This is a problem for my use case, since the more complex model I am trying to fit requires the distribution to return exactly 0 or 1.