Inhomogeneous total count not supported by enumerate_support for Binomial distribution

I have a model where one of my discrete sites is sampled from a Binomial distribution in which both total_count and probs are themselves latent variables. Unfortunately, Pyro doesn’t seem to support an inhomogeneous total_count here, and I was wondering whether there is any current workaround for this issue.

Here’s the full code to my model:

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate
from pyro.util import ignore_jit_warnings


@config_enumerate
def model(states0=None, data=None, trans_mat=None, state_prior=None):
    with ignore_jit_warnings():
        if data is not None:
            num_loci, num_samples = data.shape
        elif states0 is not None:
            num_loci, num_samples = states0.shape
        assert num_samples is not None
        assert num_loci is not None
    
    # negative binomial dispersion
    nb_r = pyro.param('expose_nb_r', torch.tensor([10000.0]), constraint=constraints.positive)

    with pyro.plate('num_samples', num_samples):

        u = pyro.sample('expose_u', dist.Normal(torch.tensor([70.]), torch.tensor([10.])))
        
        # starting state for markov chain
        if states0 is None:
            state = 2

        for l in pyro.markov(range(num_loci)):

            # sample states using HMM structure
            if states0 is None:
                temp_state_prob = trans_mat[state]
                if state_prior is not None:
                    temp_state_prob = temp_state_prob * state_prior[l]
                state = pyro.sample("state_{}".format(l), dist.Categorical(temp_state_prob),
                                 infer={"enumerate": "parallel"})
            else:
                # no need to sample state when true value provided
                state = states0[l]

            # probability of doubling for each bin
            p_doub = pyro.sample('expose_p_doub_{}'.format(l), dist.Beta(torch.tensor([1.]), torch.tensor([1.])))

            # determine how many states at this bin have doubled
            doub = pyro.sample('doub_{}'.format(l), dist.Binomial(state, p_doub))
            
            # total number of states after accounting for doubling
            total_state = state + doub

            # transform units for negative binomial sampling
            expected_obs = (u * total_state)
            nb_p = expected_obs / (expected_obs + nb_r)
            
            if data is not None:
                obs = data[l]
            else:
                obs = None
            
            full_obs = pyro.sample('obs_{}'.format(l), dist.NegativeBinomial(nb_r, probs=nb_p), obs=obs)

Note that if I swap out the two lines that compute doub and total_state for a sampling scheme in which each bin is either fully doubled or not doubled at all, the model works fine:

doub = pyro.sample('doub_{}'.format(l), dist.Bernoulli(p_doub))
total_state = state * (1. + doub)

I’ve also been able to get the model to run by fixing state at each bin (via the states0 argument) and switching the num_samples plate to a for-loop; however, that version of the model runs very slowly, and I would like to avoid fixing state.

Any thoughts here?

For posterity, I think my issue is related to this open PyTorch issue on multinomial sampling: torch.distributions.multinomial.Multinomial cannot be used in batch (pytorch/pytorch#42407 on GitHub).

One workaround is to split your batched Binomial into multiple Binomials, each of which has a homogeneous total_count. You can either do this naively by splitting into independent Binomials, or do it more cleverly by collating by total_count; I’d recommend the former until you hit compute cost issues.
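
As a minimal sketch of the naive version (the tensors total_counts and probs below are hypothetical stand-ins for your latent quantities), the point is just that each per-element sample site sees a single, homogeneous total_count:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

# Hypothetical stand-ins for an inhomogeneous total_count and its probs.
total_counts = torch.tensor([2., 3., 4.])
probs = torch.tensor([0.1, 0.5, 0.9])

@config_enumerate
def split_model():
    # Naive split: one Binomial site per element, so each site's
    # total_count is a single homogeneous value that can be enumerated.
    for i in pyro.plate("split", len(total_counts)):
        pyro.sample("doub_{}".format(i),
                    dist.Binomial(total_counts[i], probs[i]))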

Another workaround is to embed your Binomial into a Categorical whose size is an upper bound on total_count. Then you can set the Categorical logits based on Binomial log_probs and enumerate homogeneously. Here’s an attempt at this version:

import math
import torch
import pyro
from pyro.distributions import Binomial, Categorical

binomial = Binomial(total_count.unsqueeze(-1), probs.unsqueeze(-1))
values = torch.arange(total_count.max().item() + 1)
upper_bounds = total_count.unsqueeze(-1)
# clamp values so log_prob is evaluated only at supported counts
cat_logits = binomial.log_prob(torch.min(values, upper_bounds))
# values above each element's total_count get zero probability
cat_logits = cat_logits.masked_fill(values > upper_bounds, -math.inf)
z = pyro.sample("z", Categorical(logits=cat_logits),
                infer={"enumerate": "parallel"})
z = z.float()  # if you want to mimic the Binomial interface

Note that the Categorical samples will have dtype torch.long rather than Binomial’s torch.float.
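
In your model above, this might be dropped in roughly as follows (a sketch only, replacing the original doub and total_state lines and reusing the imports from the snippets above; max_state is an assumed, known upper bound on state, and the final .float() cast handles the dtype note above):

# Sketch of the Categorical workaround inside the pyro.markov loop,
# assuming max_state is a known upper bound on the latent state.
values = torch.arange(max_state + 1.)
upper_bounds = state.float().unsqueeze(-1)
binomial = dist.Binomial(upper_bounds, p_doub.unsqueeze(-1))
cat_logits = binomial.log_prob(torch.min(values, upper_bounds))
cat_logits = cat_logits.masked_fill(values > upper_bounds, -math.inf)
doub = pyro.sample('doub_{}'.format(l),
                   dist.Categorical(logits=cat_logits),
                   infer={"enumerate": "parallel"}).float()
total_state = state + doub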

The deep reason that Pyro can’t enumerate over a heterogeneous total_count is that the enumeration tensors would have shapes that depend on total_count, and we don’t support ragged tensors.
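
If it helps to see the restriction in isolation, here is roughly what happens at the PyTorch level (a small stand-alone illustration, independent of Pyro’s enumeration machinery):

import torch
from torch.distributions import Binomial

# Homogeneous total_count: the support 0..total_count has a single shape.
d = Binomial(torch.tensor([3., 3.]), torch.tensor([0.2, 0.7]))
print(d.enumerate_support().shape)  # torch.Size([4, 2])

# An inhomogeneous total_count would need a ragged support, so it raises.
d = Binomial(torch.tensor([2., 3.]), torch.tensor([0.2, 0.7]))
try:
    d.enumerate_support()
except NotImplementedError as err:
    print(err)  # "Inhomogeneous total count not supported by enumerate_support"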