Hi everyone,
I'm relatively new to Pyro and I am trying to create a model using several Bernoulli distributions within plates. I have been able to reproduce my error with the following snippet:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
pyro.clear_param_store()
""" create a test to reproduce my error """
# create some random data
y = torch.randn(10, 25).to(dtype=torch.double)
# define model
def model_test(data):
    with pyro.plate('G', 5):
        vs = []
        for i in range(5):
            v = pyro.sample(f"bernoulli_{i}", dist.Bernoulli(0.3))  # fail
            #v = pyro.sample(f"bernoulli_{i}", dist.ContinuousBernoulli(0.3))  # run
            print(f"node {i}", v)
            vs.append(v)
    # concatenate the 5 sites (5 values each) into one vector of length 25
    vs = torch.cat(vs, dim=0)
    pyro.sample("x", dist.Normal(vs, 0.1), obs=data)
# try MCMC
nuts_kernel = NUTS(model_test)
init_mcmc = MCMC(nuts_kernel, num_samples=20, warmup_steps=0)
init_mcmc.run(y)
# look at the model graph and trace
trace = pyro.poutine.trace(model_test).get_trace(y)
pyro.render_model(model_test, model_args=(y,), filename="generated_graph.pdf")
When I run the above, I get a strange behaviour: the shape of the Bernoulli values seems to grow by one dimension with each new site:
bernoulli_0 dist 5 |
value 2 1 |
bernoulli_1 dist 5 |
value 2 1 1 |
bernoulli_2 dist 5 |
value 2 1 1 1 |
bernoulli_3 dist 5 |
value 2 1 1 1 1 |
bernoulli_4 dist 5 |
value 2 1 1 1 1 1 |
This is strange to me, because another distribution such as ContinuousBernoulli does not produce this behaviour. The only related post I found is the following one, but I did not understand the answer: Bernoulli does not sample in the amount of batch_size · Issue #2750 · pyro-ppl/pyro · GitHub
Could anyone help me figure out how to fix my error?
Thanks in advance