# Relaxed Bernoulli Straight Through returns non-integer numbers when running MCMC

Hello,

I’m a bit confused about the behaviour of this program (it is not aiming to model anything useful).

```python
import pyro
import torch


def pyro_model(true_data):
    p = pyro.sample("p", pyro.distributions.Uniform(0, 1))
    b = pyro.sample(
        "b",
        pyro.distributions.RelaxedBernoulliStraightThrough(
            probs=p, temperature=torch.tensor(0.1)
        ),
    )
    print(b)
    pyro.sample("c", pyro.distributions.Normal(b, 1), obs=true_data)


true_data = torch.tensor([2.0])
mcmc_kernel = pyro.infer.mcmc.NUTS(pyro_model)
mcmc = pyro.infer.MCMC(
    mcmc_kernel,
    num_samples=5,
    warmup_steps=0,
)
mcmc.run(true_data)
```

Notice the print statement after sampling from the `RelaxedBernoulliStraightThrough` distribution. I would expect the sample to be 0 or 1, but it is returning floating-point numbers between 0 and 1. This only happens when the model is run inside MCMC, not when it is called directly.
Am I misunderstanding something here?

This is a problem for my use case since the more complex model I am trying to fit requires the return value of the distribution to be 0 or 1.

Thanks!

`RelaxedBernoulliStraightThrough` is only meant to be used with variational inference; afaik there is no reason you would ever use this with MCMC. The difference in behavior has to do with whether `sample` or `rsample` is called under the hood.

Thanks @martinjankowiak for the prompt response (as always!).
I am trying to use HMC, so I imagined that the Bernoulli distribution would present problems, since it is not differentiable, that is why I was using the StraightThrough version.

I have tried replacing it with a plain `Bernoulli`, but then I run into problems when I pass a tensor of probabilities:

```python
import pyro
import torch


def pyro_model(true_data):
    p = pyro.sample("p", pyro.distributions.Uniform(0, 1))
    q = pyro.sample("q", pyro.distributions.Uniform(0, 1))
    r = pyro.sample("r", pyro.distributions.Uniform(0, 1))
    probs = torch.tensor([p, q, r])
    b = pyro.sample(
        "b",
        pyro.distributions.Bernoulli(probs=probs),
    )
    print("b")
    print(b)
    pyro.sample("c", pyro.distributions.Normal(b, 1), obs=true_data)


true_data = torch.tensor([2.0, 3.0, 4.0])
mcmc_kernel = pyro.infer.mcmc.NUTS(pyro_model)
mcmc = pyro.infer.MCMC(
    mcmc_kernel,
    num_samples=5,
    warmup_steps=0,
)
mcmc.run(true_data)
```

Pyro now complains because the Bernoulli site returns `tensor([0., 1.])` rather than a 3-element tensor. I have been reading the docs and understand that this is because of enumeration, but I'm having a hard time following them. Could you show me how to make this example work? Thanks!

For simple use cases like this I suggest using `MixtureSameFamily`; otherwise take a closer look at the docs.

Also, afaik building `probs` with `torch.tensor([p, q, r])` may cause problems with gradients and such; you should use `torch.stack` (or `torch.cat` for tensors that are at least 1-d) instead.
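A small illustration of that gradient issue (my own example; note that `torch.cat` rejects 0-d inputs, so for the scalar values returned by `pyro.sample` above, `torch.stack` is the natural choice):

```python
import torch

p = torch.tensor(0.3, requires_grad=True)
q = torch.tensor(0.6, requires_grad=True)

# torch.tensor copies the values and detaches them from the autograd graph
detached = torch.tensor([p, q])
# torch.stack keeps p and q connected to the graph
connected = torch.stack([p, q])

print(detached.requires_grad)   # False: no gradients flow back to p and q
print(connected.requires_grad)  # True

connected.sum().backward()
print(p.grad)  # tensor(1.)
```

With the `torch.tensor` version, HMC/NUTS would silently see `probs` as a constant and never propose moves informed by the likelihood's dependence on `p`, `q`, and `r`.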