Relaxed Bernoulli Straight Through returns non-integer numbers when running MCMC

Hello,

I’m a bit confused about the behaviour of this program (it is not meant to model anything useful).

```python
import pyro
import torch

def pyro_model(true_data):
    p = pyro.sample("p", pyro.distributions.Uniform(0, 1))
    b = pyro.sample(
        "b",
        pyro.distributions.RelaxedBernoulliStraightThrough(
            probs=p, temperature=torch.tensor(0.1)
        ),
    )
    print(b)
    pyro.sample("c", pyro.distributions.Normal(b, 1), obs=true_data)


true_data = torch.tensor([2.0])
mcmc_kernel = pyro.infer.mcmc.NUTS(pyro_model)
mcmc = pyro.infer.MCMC(
    mcmc_kernel,
    num_samples=5,
    warmup_steps=0,
)
mcmc.run(true_data)
```

Notice that there is a print statement right after sampling from the RelaxedBernoulliStraightThrough distribution. I would expect the distribution to return 0 or 1, but it returns floating-point numbers between 0 and 1. This only happens when the model is run inside MCMC, not when it is run directly.
Am I misunderstanding something here?

This is a problem for my use case since the more complex model I am trying to fit requires the return value of the distribution to be 0 or 1.

Thanks!

RelaxedBernoulliStraightThrough is only meant to be used with variational inference; afaik there is no reason you would ever use it with MCMC. The difference in behavior has to do with whether sample or rsample is called under the hood: rsample returns the hard (0 or 1) straight-through value, while sample returns the relaxed value in (0, 1).

Thanks @martinjankowiak for the prompt response (as always!).
I am trying to use HMC, so I assumed the Bernoulli distribution would cause problems because it is not differentiable; that is why I was using the StraightThrough version.

I have tried replacing it with a Bernoulli, but then I run into problems when I pass a tensor of probabilities:

```python
import pyro
import torch


def pyro_model(true_data):
    p = pyro.sample("p", pyro.distributions.Uniform(0, 1))
    q = pyro.sample("q", pyro.distributions.Uniform(0, 1))
    r = pyro.sample("r", pyro.distributions.Uniform(0, 1))
    probs = torch.tensor([p, q, r])
    b = pyro.sample(
        "b",
        pyro.distributions.Bernoulli(
            probs=probs
        )
    )
    print("b")
    print(b)
    pyro.sample("c", pyro.distributions.Normal(b, 1), obs=true_data)


true_data = torch.tensor([2.0, 3.0, 4.0])
mcmc_kernel = pyro.infer.mcmc.NUTS(pyro_model)
mcmc = pyro.infer.MCMC(
    mcmc_kernel,
    num_samples=5,
    warmup_steps=0,
)
mcmc.run(true_data)
```

Pyro now complains because sampling from the Bernoulli returns tensor([0,1]) rather than a tensor with three values. I have been reading the docs and I understand that this is because of enumeration, but I’m having a bit of a hard time understanding them. Could you show me how to make this example work? Thanks!

For simple use cases like this I suggest using MixtureSameFamily; otherwise, take a closer look at the enumeration docs.

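Also, building probs with torch.tensor([p, q, r]) will cause problems with gradients and such: it copies the values into a new tensor that is detached from the autograd graph. You should use cat instead (since p, q and r are 0-dim here, torch.stack([p, q, r]) is simpler, because torch.cat only accepts tensors with at least one dimension).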
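A quick check of the gradient point in plain PyTorch. The scalar values 0.2, 0.5 and 0.9 are made up and only stand in for the sampled `p`, `q` and `r`.

```python
import torch

# Made-up scalars standing in for the sampled p, q, r.
p = torch.tensor(0.2, requires_grad=True)
q = torch.tensor(0.5, requires_grad=True)
r = torch.tensor(0.9, requires_grad=True)

detached = torch.tensor([p, q, r])  # values copied, no autograd history
stacked = torch.stack([p, q, r])    # differentiable, shape (3,)
catted = torch.cat([p.unsqueeze(0), q.unsqueeze(0), r.unsqueeze(0)])  # also differentiable

# torch.cat refuses 0-dim inputs, which is why they are unsqueezed above.
try:
    torch.cat([p, q, r])
    cat_on_scalars_ok = True
except RuntimeError:
    cat_on_scalars_ok = False

stacked.sum().backward()  # gradients flow back to p, q, r
print(detached.requires_grad, stacked.requires_grad, p.grad)
```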