# Best way to vectorize using pyro.plate

Hi pyro forums,

Problem
I’m having trouble understanding why NUTS sampling slows down so much as I add independent random variables that are shaped the same as my existing random variables. I thought the model would be vectorized, so adding independent variables shouldn’t blow up the computation time the way it does.

Context
I am trying to perform inference on multiple independent latent skills. Each skill has two types, and each pair of skill and type has an array of `alpha` shape parameters and an equal-sized array of `beta` scale parameters for a Beta distribution. Each pair of `alpha[i]` and `beta[i]` is an independent sample for each pair of observations for each `skill`, `type` pair.

`alpha` and `beta` are usually two to four in length.
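To make the structure concrete, here is a hypothetical miniature of the layout described above (skill and type names and all numbers are made up, and plain lists stand in for tensors for brevity):

```python
# Hypothetical illustration of the skills_dict layout; names and numbers invented.
skills_dict = {
    "python": {
        "required": {"alphas": [2.0, 3.0, 1.5], "betas": [5.0, 4.0, 2.0]},
        "optional": {"alphas": [1.0, 2.0, 2.5], "betas": [3.0, 3.0, 1.0]},
    },
    "sql": {
        "required": {"alphas": [4.0, 2.0, 3.0], "betas": [2.0, 2.0, 4.0]},
        "optional": {"alphas": [1.5, 1.0, 2.0], "betas": [2.5, 3.0, 3.5]},
    },
}
```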

The model below runs quickly when `skills_dict` contains only two to four skill–type pairs, but once I expand this to 50 to 60 pairs the runtime grows far faster than linearly. For example, for 500 samples:

- 50–60 items: 1 hr
- 12 items: 3 min
- 8 items: 1 min
- 4 items: 20 sec
- 2 items: 9 sec

I have found that looping in Python and running individual HMC for each `skill, type` pair is faster than passing every `skill, type` pair to Pyro and performing HMC once. I'm confused as to why this is the case.

Question
Are the two nested `for i in pyro.plate()` loops in the `all_skills` function the most efficient way to perform inference on each independent `skill`, `type` pair?

My model code is below; I run it with `run_inference(all_skills, [skills_dict], n_samples=300)`.

```python
import pyro
import pyro.distributions as dist

# prior_a, prior_b and obs are module-level globals (defined elsewhere).


def all_skills(skills_dict):
    """Model code to infer the overall distribution of the role requirements
    for every skill.

    Parameters
    ----------
    skills_dict : dict of dict of dict of list
        Dictionary containing all alpha and beta parameters for each skill.
        Structure is <skill>: <type>: <param>: <tensor of floats>
    """
    skills = list(skills_dict.keys())
    types = list(skills_dict[skills[0]].keys())
    for i in pyro.plate("skills", len(skills)):
        skill = skills[i]
        for j in pyro.plate(skill + "types", len(types)):
            type_ = types[j]
            skill_given_skillSets(
                skill + "_" + type_,
                skills_dict[skill][type_]["alphas"],
                skills_dict[skill][type_]["betas"],
            )


def skill_given_skillSets(attributeIdAndType, alphas, betas):
    """Given an array of alphas and betas for a set of Beta
    distributions for a given skillId, the model that infers
    the overall distribution of the skill.

    Parameters
    ----------
    attributeIdAndType : str
        The attributeId with type appended.
    alphas : tensor of floats
        The alpha parameters of the corresponding Beta distributions
    betas : tensor of floats
        The beta parameters of the corresponding Beta distributions
    """
    attribute = pyro.sample(attributeIdAndType, dist.Beta(prior_a, prior_b))  # prior
    with pyro.plate(attributeIdAndType + "set_obs", len(alphas)):
        prob = pyro.deterministic(
            attributeIdAndType + "_prob",
            dist.Beta(alphas, betas).log_prob(attribute).exp(),
        )
        pyro.sample(f"{attributeIdAndType}_obs", dist.Bernoulli(probs=prob), obs=obs)


def run_inference(model, model_args, n_samples=1000, warmup=200, **model_kwargs):
    """
    Generic function that takes a complete model function and all model
    inputs, then runs MCMC over the model.
    """
    from pyro.infer import MCMC, NUTS

    nuts_kernel = NUTS(model)  # initialise sampler
    mcmc = MCMC(nuts_kernel, num_samples=n_samples, warmup_steps=warmup)
    mcmc.run(*model_args, **model_kwargs)  # args feed into the model
    hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
    return hmc_samples
```

Hi @tdennisliu, note that `for i in pyro.plate(...)` is not vectorized but sequential, which is why inference is slow. To vectorize a plate you’ll need to use the context manager form `with pyro.plate(...) as i:`, where the resulting `i` is a tensor. See the tensor shapes tutorial for an introduction to nested vectorized plates.

In your case I would recommend converting the `skills_dict` to a `torch.tensor(..., dtype=torch.long)` index and avoiding the `skills[i]` indexing. This kind of integer array programming is common to most forms of vectorization.
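As a sketch of that conversion, a hypothetical helper (name and dict layout assumed from the post above) that flattens the nested dict into stacked tensors plus a list mapping row indices back to `(skill, type)` pairs:

```python
import torch


def flatten_skills(skills_dict):
    # Hypothetical helper: stack the per-(skill, type) alpha/beta lists into
    # (n_pairs, n_obs) tensors that a vectorized plate can consume.
    # Assumes every pair has the same number of observations.
    pairs = [(s, t) for s in skills_dict for t in skills_dict[s]]
    alphas = torch.stack(
        [torch.as_tensor(skills_dict[s][t]["alphas"], dtype=torch.float) for s, t in pairs]
    )
    betas = torch.stack(
        [torch.as_tensor(skills_dict[s][t]["betas"], dtype=torch.float) for s, t in pairs]
    )
    return pairs, alphas, betas


demo = {
    "python": {"required": {"alphas": [2.0, 3.0], "betas": [5.0, 4.0]}},
    "sql": {"required": {"alphas": [1.0, 2.0], "betas": [3.0, 3.0]}},
}
pairs, alphas, betas = flatten_skills(demo)
print(pairs)         # [('python', 'required'), ('sql', 'required')]
print(alphas.shape)  # torch.Size([2, 2])
```

Row `i` of the tensors then corresponds to `pairs[i]`, so sample sites can be addressed by integer index rather than by string-keyed dict lookups inside the model.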


Thank you @fritzo. I knew I was building the probabilistic graph sequentially, but I had wrongly assumed the inference algorithm would examine the graph structure, recognize the variables as independent, and vectorize them at the sampling stage rather than following the order in which the graph was built.

Following your suggestion to use `with pyro.plate(...):`, the model is much faster, as expected.