Help with enumeration with Bernoulli distribution

Hello!

I’m a bit stuck trying to figure out how to make my model work. To give you an idea, it is a simple disease model. We have 10 people, each with a random number representing their susceptibility to infection. We then infect a certain number of people (in this test case just the first one). We have a parameter, β, which accounts for the infectivity per contact (and is the one I’m looking to infer). The vector of probabilities of each individual getting infected is then given by 1-\exp(-\beta \vec{S} \sum_i I_i), where I_i is 1 if individual i is currently infected and 0 otherwise. I then run two iterations of this model.
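
Spelled out per individual (p_j is just my shorthand for the probability that person j gets infected in a given iteration):

p_j = 1 - \exp\left(-\beta \, S_j \sum_i I_i\right)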

The first thing I do is to create some fake data to fit to:

import pyro
import torch

# true parameter values we want to infer
true_log_beta = -0.2
true_beta = 10**true_log_beta # infectivity per contact
n = 10 # number of people
S = torch.rand(n) # individual susceptibilities

def get_data():
    infected = torch.zeros(n)
    infected[0] = 1.0  # person 0 starts out infected
    # iteration 1: per-person infection probability, then draw who gets infected
    alpha1 = 1.0 - torch.exp(-true_beta * S * infected.sum())
    b1 = pyro.distributions.Bernoulli(alpha1).sample()
    not_infected = 1.0 - infected
    infected = infected + not_infected * b1  # only previously uninfected people can flip
    # iteration 2: same update with the new number of infected
    alpha2 = 1.0 - torch.exp(-true_beta * S * infected.sum())
    b2 = pyro.distributions.Bernoulli(alpha2).sample()
    return b2
data = get_data()
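
For reference, the shapes involved (these prints are just a sanity check, not part of the model):

print(S.shape)     # torch.Size([10])
print(data.shape)  # torch.Size([10]) -- one 0./1. outcome per person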

Then I set up my Pyro model:

def model(data):
    log_beta = pyro.sample("log_beta", pyro.distributions.Normal(0,1))
    beta = torch.pow(10, log_beta)
    infected = torch.zeros(n)
    infected[0] = 1.0
    alpha1 = 1.0 - torch.exp(-beta * S * infected.sum())
    b1 = pyro.sample("b1", pyro.distributions.Bernoulli(alpha1)) # Shape should be (n,)
    not_infected = 1.0 - infected
    infected = infected + not_infected * b1
    alpha2 = 1.0 - torch.exp(-beta * S * infected.sum())
    pyro.sample("b2", pyro.distributions.Bernoulli(alpha2), obs=data)

and run inference with NUTS:

def run_mcmc(model, data, **kwargs):
    pyro.clear_param_store()
    nuts_kernel = pyro.infer.NUTS(model)
    mcmc = pyro.infer.MCMC(nuts_kernel, **kwargs)
    mcmc.run(data)
    return mcmc

mcmc = run_mcmc(model, data, num_samples=10, warmup_steps=10)
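
(Not part of the problem itself, but for completeness: once the sampler runs, I would read off the posterior with something like the following.)

samples = mcmc.get_samples()["log_beta"]
print(samples.mean(), samples.std())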

Now, as you may already have noticed, this does not work properly because I’m not doing the enumeration correctly (I guess). The “b1” sample statement returns [[0], [1]], i.e. the possible values of the Bernoulli distribution, but I don’t know how to make this work with the structure I have, since I would expect a vector of 1s and 0s of length 10. I’ve gone through the Pyro tutorials (Inference with Discrete Latent Variables — Pyro Tutorials 1.8.4 documentation and Tensor shapes in Pyro — Pyro Tutorials 1.8.4 documentation) but I still can’t figure it out. Any help would be appreciated!


I recently ran into this exact issue. I suspect it’s intentional that this isn’t really possible: the space being enumerated over grows exponentially with the number of Bernoulli random variables, since any combination of 0s and 1s is possible.

For what it’s worth, using SVI with the TraceGraph_ELBO estimator worked decently well for me as an alternative to enumeration.
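
In case it helps, here is a rough sketch of what that could look like for the model above. The guide structure, the parameter names (loc, scale, b1_probs), and the optimizer settings are my own assumptions, not something from the original post:

def guide(data):
    # variational posterior for log_beta: a Normal with learnable location and scale
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0),
                       constraint=pyro.distributions.constraints.positive)
    pyro.sample("log_beta", pyro.distributions.Normal(loc, scale))
    # variational posterior for the discrete b1 sites: independent Bernoullis
    probs = pyro.param("b1_probs", 0.5 * torch.ones(n),
                       constraint=pyro.distributions.constraints.unit_interval)
    pyro.sample("b1", pyro.distributions.Bernoulli(probs))

pyro.clear_param_store()
svi = pyro.infer.SVI(model, guide,
                     pyro.optim.Adam({"lr": 0.01}),
                     loss=pyro.infer.TraceGraph_ELBO())
for step in range(2000):
    svi.step(data)

After training, 10 ** pyro.param("loc").item() should give a rough point estimate of beta (again, assuming the parameter names above). TraceGraph_ELBO uses score-function gradients for the discrete sites sampled in the guide, so it avoids enumerating them entirely.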
