Help with enumeration with Bernoulli distribution


I’m a bit stuck trying to get my model to work. To give you an idea, it is a simple disease model. We have 10 people, each with a random number representing their susceptibility to infection. Then we infect a certain number of people (in this test case, just the first one). We have a parameter, β, which accounts for the infectivity per contact and is the one I’m looking to infer. The vector of probabilities of each individual getting infected is then given by $1 - \exp\left(-\beta \vec{S} \sum_i I_i\right)$ (exponential applied elementwise), where $I_i$ is 1 if individual $i$ is infected and 0 otherwise. I then run two iterations of this model.
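To make the formula concrete: every person sees the same force of infection $\beta \sum_i I_i$, scaled by their own susceptibility $S_j$. Below is a minimal plain-Python sketch of that formula; the function name `infection_probabilities` and the three example susceptibilities are hypothetical, and only the value of β matches `true_log_beta = -0.2` from the data-generation code.

```python
import math


def infection_probabilities(beta, S, infected):
    """alpha_j = 1 - exp(-beta * S_j * sum_i I_i) for every person j."""
    force = beta * sum(infected)  # infectivity per contact times number infected
    return [1.0 - math.exp(-force * s) for s in S]


beta = 10 ** -0.2        # same value as true_beta in the data-generation code
S = [0.5, 0.1, 0.9]      # hypothetical susceptibilities for three people
infected = [1, 0, 0]     # only the first person is infected
probs = infection_probabilities(beta, S, infected)
print(probs)
```

As in the model, everyone gets a probability, including people who are already infected.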

The first thing I do is create some fake data to fit:

```python
import pyro
import torch

# true parameter values we want to infer
true_log_beta = -0.2
true_beta = 10**true_log_beta  # infectivity per contact
n = 10  # number of people
S = torch.rand(n)  # individual susceptibilities

def get_data():
    # round 1: only person 0 starts infected
    infected = torch.zeros(n)
    infected[0] = 1.0
    alpha1 = 1.0 - torch.exp(-true_beta * S * infected.sum())
    b1 = pyro.distributions.Bernoulli(alpha1).sample()
    # newly infected people join the infected set
    not_infected = 1.0 - infected
    infected = infected + not_infected * b1
    # round 2: infection probabilities given the updated infected set
    alpha2 = 1.0 - torch.exp(-true_beta * S * infected.sum())
    b2 = pyro.distributions.Bernoulli(alpha2).sample()
    return b2

data = get_data()
```

Then I set up my Pyro model:

```python
def model(data):
    log_beta = pyro.sample("log_beta", pyro.distributions.Normal(0,1))
    beta = torch.pow(10, log_beta)
    infected = torch.zeros(n)
    infected[0] = 1.0
    alpha1 = 1.0 - torch.exp(-beta * S * infected.sum())
    b1 = pyro.sample("b1", pyro.distributions.Bernoulli(alpha1)) # Shape should be (n,)
    not_infected = 1.0 - infected
    infected = infected + not_infected * b1
    alpha2 = 1.0 - torch.exp(-beta * S * infected.sum())
    pyro.sample("b2", pyro.distributions.Bernoulli(alpha2), obs=data)
```

and inference with NUTS:

```python
def run_mcmc(model, data, **kwargs):
    nuts_kernel = pyro.infer.NUTS(model)
    mcmc = pyro.infer.MCMC(nuts_kernel, **kwargs)
    mcmc.run(data)  # run the sampler; without this call nothing is sampled
    return mcmc

mcmc = run_mcmc(model, data, num_samples=10, warmup_steps=10)
```

Now, as you may have already noticed, this does not work properly because (I guess) I’m not doing enumeration correctly. The “b1” sample statement returns [[0],[1]], i.e. a tensor of shape (2, 1) holding the two possible values of the Bernoulli distribution, but I don’t know how to make this work with my structure, since I would expect a vector of 1s and 0s of length 10. I’ve gone through the Pyro tutorials (“Inference with Discrete Latent Variables” and “Tensor shapes in Pyro”, Pyro Tutorials 1.8.1 documentation), but I still can’t figure it out. Any help would be appreciated!
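As far as I can tell, the (2, 1) shape is what `enumerate_support(expand=False)` gives for a Bernoulli with batch shape (10,): two values that broadcast against the batch, so each enumerated `b1` is all zeros or all ones rather than one of the 2^10 joint configurations. A `pyro.plate` alone may not fix this, because `alpha2` depends on `b1` through `infected.sum()`, a sum over the very dimension a plate would declare independent. One brute-force way to see what exact marginalization requires is to sum over every joint value of `b1` explicitly. The sketch below does this in plain Python (no Pyro or torch) for a fixed β; the helpers `infection_probs`, `bernoulli_log_prob`, `logsumexp` and `log_marginal_likelihood` and the example data are hypothetical, and this illustrates the quantity to compute rather than a Pyro recipe.

```python
import itertools
import math


def infection_probs(beta, S, infected):
    # alpha_j = 1 - exp(-beta * S_j * sum_i I_i), applied elementwise over S
    total = sum(infected)
    return [1.0 - math.exp(-beta * s * total) for s in S]


def bernoulli_log_prob(p, x):
    # log P(X = x) for X ~ Bernoulli(p); -inf for impossible outcomes
    q = p if x == 1 else 1.0 - p
    return math.log(q) if q > 0.0 else -math.inf


def logsumexp(xs):
    # numerically stable log(sum(exp(x) for x in xs))
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_marginal_likelihood(beta, S, data):
    """log p(data | beta) with the round-1 outcome b1 summed out by brute force."""
    n = len(S)
    infected0 = [1.0] + [0.0] * (n - 1)  # person 0 starts infected
    alpha1 = infection_probs(beta, S, infected0)
    terms = []
    # enumerate every joint configuration of b1 (2**n of them)
    for b1 in itertools.product((0, 1), repeat=n):
        log_p_b1 = sum(bernoulli_log_prob(a, b) for a, b in zip(alpha1, b1))
        # same update as the model: only the not-yet-infected can become infected
        infected1 = [i + (1.0 - i) * b for i, b in zip(infected0, b1)]
        alpha2 = infection_probs(beta, S, infected1)
        log_p_data = sum(bernoulli_log_prob(a, int(d)) for a, d in zip(alpha2, data))
        terms.append(log_p_b1 + log_p_data)
    return logsumexp(terms)


# usage: hypothetical susceptibilities and observations for 4 people
S = [0.3, 0.8, 0.5, 0.1]
data = [1, 1, 0, 0]
print(log_marginal_likelihood(10 ** -0.2, S, data))
```

With n = 10 this is 1024 terms, which is cheap. The same log-sum-exp written with torch ops could, presumably, replace the `b1` and `b2` sample statements via `pyro.factor`, leaving only the continuous `log_beta` for NUTS. Since `alpha2` depends on `b1` only through the number of new infections, a smarter version could enumerate that count (its Poisson-binomial distribution) instead of all 2^n vectors.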
