Best way to vectorize using pyro.plate

Hi pyro forums,

I’m having trouble understanding why my NUTS inference slows down as I add independent random variables that are shaped the same as existing ones. I thought sampling would be vectorized, so adding variables shouldn’t blow up the computation time the way it does.

I am trying to perform inference on multiple independent latent skills. For context, each skill has two types, and each (skill, type) pair has an array of alpha shape parameters and an equal-sized array of beta parameters for a Beta distribution. Each pair alpha[i], beta[i] is an independent sample for each pair of observations for each (skill, type) pair.

alpha and beta are usually two to four in length.

The model below runs quickly when skills_dict only contains two to four pairs of skill and type, but once I expand this to 50 to 60 pairs, it slows exponentially. For example, for 500 samples:

50-60 items took 1 hr
12 items took 3 min
8 items took 1 min
4 items took 20 sec
2 items took 9 sec

I have found that running an individual HMC chain for each (skill, type) pair in a Python loop is faster than passing every pair to Pyro and running HMC once. I'm confused as to why this is the case.

Is the way I have nested the two for i in pyro.plate() loops in the all_skills function the most efficient way to perform inference on each independent (skill, type) pair?

My model code is below, but I run: run_inference(all_skills,[skills_dict],n_samples=300)

def all_skills(skills_dict):
    """Model code to infer the overall distribution of the role requirements
    for every skill.

    skills_dict : dict of dict of dict of tensor
        Dictionary containing all alpha and beta parameters for each skill.
        Structure is <skill>: <type>: <param>: <tensor of floats>
    """
    skills = list(skills_dict.keys())
    types = list(skills_dict[skills[0]].keys())
    for i in pyro.plate("skills", len(skills)):
        skill = skills[i]
        for j in pyro.plate(skill + "types", len(types)):
            type_ = types[j]
            skill_given_skillSets(
                skill + "_" + type_,
                skills_dict[skill][type_]["alpha"],
                skills_dict[skill][type_]["beta"],
            )

def skill_given_skillSets(attributeIdAndType, alphas, betas):
    """Given an array of alphas and betas for a set of Beta
    distributions for a given skillId, the model that infers
    the overall distribution of the skill.

    attributeIdAndType : str
        The attributeId with type appended.
    alphas : tensor of floats
        The alpha parameters of the corresponding Beta distributions.
    betas : tensor of floats
        The beta parameters of the corresponding Beta distributions.
    """
    obs = torch.tensor(1.0, requires_grad=False)
    # prior_a, prior_b are module-level prior hyperparameters
    attribute = pyro.sample(attributeIdAndType, dist.Beta(prior_a, prior_b))  # prior
    with pyro.plate(attributeIdAndType + "set_obs", len(alphas)):
        prob = pyro.deterministic(
            attributeIdAndType + "_prob",
            dist.Beta(alphas, betas).log_prob(attribute).exp(),
        )
        pyro.sample(f"{attributeIdAndType}_obs", dist.Bernoulli(probs=prob), obs=obs)

def run_inference(model, model_args, n_samples=1000, warmup=200, **model_kwargs):
    """Generic function to input a complete model function and all model
    inputs, then run MCMC over the model.
    """
    from pyro.infer import MCMC, NUTS

    nuts_kernel = NUTS(model)  # initialise sampler
    mcmc = MCMC(nuts_kernel, num_samples=n_samples, warmup_steps=warmup)  # initialise mcmc object
    mcmc.run(*model_args, **model_kwargs)  # args feed into the model
    hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
    return hmc_samples

Hi @tdennisliu, note that for i in pyro.plate(...) is not vectorized but sequential; that's what is causing the slow inference. To vectorize a plate you'll need the context manager form with pyro.plate(...) as i:, where the resulting i is a tensor of indices. See the tensor shapes tutorial for an introduction to nested vectorized plates.

In your case I would recommend converting the skills_dict to a torch.tensor(..., dtype=torch.long) index, and avoiding the skills[i] indexing. This kind of integer array programming is common to most forms of vectorization.


Thank you @fritzo. I knew I was building the probabilistic graph sequentially, but I had wrongly assumed the inference algorithm would examine the graph structure, recognise that the variables were independent, and vectorize them at the sampling stage, i.e. that sampling would not necessarily follow the sequential order in which the graph was built.

Working with your suggestion to use with pyro.plate(...): is much faster, as expected.