Discrete uniform from categorical leads to error during inference

I am experiencing some difficulties with a simple model from Bayesian Methods for Hackers:

import torch as T
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(obs_data):
    n = obs_data.shape[0]

    # hyperparams
    alpha = 1 / obs_data.mean()

    # 0 ...(λ1)...τ...(λ2)... n #

    with pyro.plate('lam_t', 2):
        lam_t = pyro.sample("lam", dist.Exponential(alpha))

    tau = 1 + pyro.sample("tau", dist.Categorical(T.ones(n-1) / (n-1)))
    
    lam = T.cat([
        lam_t[0].expand(tau),
        lam_t[1].expand(n - tau)
    ])

    with pyro.plate("data", len(obs_data)) as idx:
        obs = pyro.sample("obs", dist.Poisson(lam), obs=obs_data)

# Inference
kernel = NUTS(
    model, jit_compile=True, ignore_jit_warnings=True, max_tree_depth=5
)
posterior = MCMC(
    kernel, num_samples=2000, warmup_steps=500
)
posterior.run(obs_data)

During inference, the following exception is raised:

     14     lam = T.cat([
---> 15         lam_t[0].expand(tau),
     16         lam_t[1].expand(n - tau)
     17     ])
TypeError: expand(): argument 'size' (position 1) must be tuple of ints, not Tensor

It appears that the problem goes away if I sample tau from dist.Uniform(0, 1) instead and rescale it to an integer index.

Therefore, my question reduces to: is there a more subtle difference (at least regarding the size of the sampled data) between

tau = 1 + pyro.sample("tau", dist.Categorical(T.ones(n-1) / (n-1)))

and

tau = 1 + ((n-1) * pyro.sample("tau", dist.Uniform(0, 1))).long()
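To see what the Uniform version does, here is a plain-Python sketch of the same mapping, tau = 1 + floor((n-1) * u), with a toy Poisson log-likelihood at fixed rates (the helper names, counts and rates are illustrative). Because u is a continuous site, NUTS can handle it and it is not enumerated, so tau stays a single value. The catch is that the likelihood depends on u only through that integer: it is flat between the jumps, so its gradient with respect to u is zero almost everywhere and gives NUTS essentially no information about where the switchpoint is.

```python
import math


def uniform_to_tau(u, n):
    """Map u in [0, 1) to a switchpoint in {1, ..., n-1}, like 1 + ((n-1) * u).long()."""
    # int() truncates toward zero, which matches .long() for non-negative inputs
    return 1 + int((n - 1) * u)


def poisson_logpmf(k, lam):
    """Log of the Poisson probability mass function."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def loglik_given_u(u, obs, lam1, lam2):
    """Log-likelihood of the switchpoint model viewed as a function of the continuous u."""
    tau = uniform_to_tau(u, len(obs))
    return sum(poisson_logpmf(k, lam1 if i < tau else lam2) for i, k in enumerate(obs))


obs = [1, 2, 1, 9, 10, 11]  # toy counts, n = 6
n = len(obs)

# Every switchpoint 1..n-1 is reachable from some u in [0, 1).
reachable = sorted({uniform_to_tau(k / 1000, n) for k in range(1000)})
print(reachable)  # [1, 2, 3, 4, 5]

# Nearby values of u give exactly the same log-likelihood (a flat region) ...
print(loglik_given_u(0.30, obs, 1.5, 10.0) == loglik_given_u(0.31, obs, 1.5, 10.0))  # True
# ... until u crosses a boundary and tau jumps.
print(loglik_given_u(0.30, obs, 1.5, 10.0) == loglik_given_u(0.50, obs, 1.5, 10.0))  # False
```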

Thank you!

Hi @AlexD, the difference is that Categorical is a discrete distribution while Uniform is a continuous one. The NUTS algorithm only works with continuous latent variables, so if you use a Categorical distribution you need to either:

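  • use NUTS within Gibbs (I believe this is what PyMC3 does by default)
  • or marginalize out the discrete latent variable. This is the default in Pyro: NUTS enumerates tau, so inside the model tau holds all n - 1 candidate values at once (as a tensor with an extra enumeration dimension) rather than a single integer, which is why expand(tau) fails. See the enumeration tutorial for how to write models that support marginalization.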
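To make "marginalize" concrete for this model: summing tau out replaces the discrete site with log p(obs | λ1, λ2) = logsumexp over τ of [log p(τ) + Σ_i log Poisson(obs_i | λ1 if i < τ else λ2)]. The plain-Python sketch below computes that sum by hand for fixed rates, assuming the uniform prior over {1, ..., n-1} from the question; conceptually, Pyro's enumeration performs an equivalent sum, vectorized over tensors, each time NUTS evaluates the density. The helper names and toy data are illustrative, and the same terms also give p(τ | obs, λ1, λ2).

```python
import math


def poisson_logpmf(k, lam):
    """Log of the Poisson probability mass function."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def joint_terms(obs, lam1, lam2):
    """log p(tau) + log p(obs | tau, lam1, lam2) for every tau in 1..n-1."""
    n = len(obs)
    log_prior = -math.log(n - 1)  # uniform prior over the n-1 switchpoints
    return {
        tau: log_prior
        + sum(poisson_logpmf(k, lam1 if i < tau else lam2) for i, k in enumerate(obs))
        for tau in range(1, n)
    }


def marginal_loglik(obs, lam1, lam2):
    """log p(obs | lam1, lam2) with tau summed out."""
    return logsumexp(list(joint_terms(obs, lam1, lam2).values()))


def tau_posterior(obs, lam1, lam2):
    """p(tau | obs, lam1, lam2), obtained by normalizing the joint terms."""
    terms = joint_terms(obs, lam1, lam2)
    z = logsumexp(list(terms.values()))
    return {tau: math.exp(t - z) for tau, t in terms.items()}


obs = [1, 2, 1, 9, 10, 11]
post = tau_posterior(obs, 1.5, 10.0)
print(max(post, key=post.get))  # 3: the first three counts fit the low rate
```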

If you want to use NUTS within Gibbs, you can try DiscreteHMCGibbs in NumPyro.
