I have a model where one of my discrete sites is sampled from a Binomial distribution where both total_count
and probs
are also latent variables. Unfortunately, pyro doesn’t seem to support total_count
being inhomogenous and I was wondering if there’s any current workaround for this issue.
Here’s the full code to my model:
@config_enumerate
def model(states0=None, data=None, trans_mat=None, state_prior=None):
with ignore_jit_warnings():
if data is not None:
num_loci, num_samples = data.shape
elif states0 is not None:
num_loci, num_samples = states0.shape
assert num_samples is not None
assert num_loci is not None
# negative binomial dispersion
nb_r = pyro.param('expose_nb_r', torch.tensor([10000.0]), constraint=constraints.positive)
with pyro.plate('num_samples', num_samples):
u = pyro.sample('expose_u', dist.Normal(torch.tensor([70.]), torch.tensor([10.])))
# starting state for markov chain
if states0 is None:
state = 2
for l in pyro.markov(range(num_loci)):
# sample states using HMM structure
if states0 is None:
temp_state_prob = trans_mat[state]
if state_prior is not None:
temp_state_prob = temp_state_prob * state_prior[l]
state = pyro.sample("state_{}".format(l), dist.Categorical(temp_state_prob),
infer={"enumerate": "parallel"})
else:
# no need to sample state when true value provided
state = states0[l]
# probability of doubling for each bin
p_doub = pyro.sample('expose_p_doub_{}'.format(l), dist.Beta(torch.tensor([1.]), torch.tensor([1.])))
# determine how many states at this bin have doubled
doub = pyro.sample('doub_{}'.format(l), dist.Binomial(state, p_doub))
# total number of states after accounting for doubling
total_state = state + doub
# transform units for negative binomial sampling
expected_obs = (u * total_state)
nb_p = expected_obs / (expected_obs + nb_r)
if data is not None:
obs = data[l]
else:
obs = None
full_obs = pyro.sample('obs_{}'.format(l), dist.NegativeBinomial(nb_r, probs=nb_p), obs=obs)
Note that if I swap out the two lines for getting doub
and total_state
to a sampling scheme where each bin is either fully doubled or not doubled at all, the model works fine.
doub = pyro.sample('doub_{}'.format(l), dist.Bernoulli(p_doub))
total_state = state * (1. + doub)
I’ve also been able to get my model to run by fixing state
at each bin (using the states0
argument) and switching the num_samples
plate to a for-loop; however, said version of the model runs very slow and I would like to avoid fixing state
.
Any thoughts here?