Hello,
I’m a bit confused about the behaviour of the following program (it is not meant to model anything useful):
```python
import pyro
import torch


def pyro_model(true_data):
    p = pyro.sample("p", pyro.distributions.Uniform(0, 1))
    b = pyro.sample(
        "b",
        pyro.distributions.RelaxedBernoulliStraightThrough(
            probs=p, temperature=torch.tensor(0.1)
        ),
    )
    print(b)
    pyro.sample("c", pyro.distributions.Normal(b, 1), obs=true_data)


true_data = torch.tensor([2.0])
mcmc_kernel = pyro.infer.mcmc.NUTS(pyro_model)
mcmc = pyro.infer.MCMC(
    mcmc_kernel,
    num_samples=5,
    warmup_steps=0,
)
mcmc.run(true_data)
```
Note the `print` statement right after sampling from the `RelaxedBernoulliStraightThrough` distribution. I would expect the distribution to return 0 or 1, but it returns floating-point numbers between 0 and 1. This only happens when the model runs inside MCMC, not when I call it directly.
Am I misunderstanding something here?
This is a problem for my use case, since the more complex model I am trying to fit requires the distribution to return exactly 0 or 1.
Thanks!