Hi pyro forums,
## Problem
I'm having trouble understanding why NUTS sampling slows down so dramatically as I add independent random variables that have the same shape as the existing ones. I thought the model should be vectorized in a way where adding variables wouldn't blow up the computation time like this.
## Context
I am trying to perform inference on multiple independent latent skills. For context, each skill has two types, and each (`skill`, `type`) pair has an array of `alpha` shape parameters and an equal-sized array of `beta` scale parameters for a Beta distribution. Each `alpha[i]`/`beta[i]` pair is an independent sample for each observation of a given (`skill`, `type`) pair. `alpha` and `beta` are usually two to four elements long.
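To make the structure concrete, here is an illustrative `skills_dict` (the skill and type names and the parameter values are made up for this example):

```python
import torch

# Illustrative structure only; names and values are made up.
skills_dict = {
    "python": {
        "core": {
            "alphas": torch.tensor([2.0, 3.5, 1.2]),
            "betas": torch.tensor([5.0, 2.0, 4.1]),
        },
        "nice_to_have": {
            "alphas": torch.tensor([1.5, 2.2]),
            "betas": torch.tensor([3.0, 1.8]),
        },
    },
    # ... one entry like this per skill (50-60 of them in the slow case)
}
```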
The model below runs quickly when `skills_dict` contains only two to four (`skill`, `type`) pairs, but once I expand this to 50-60 pairs it slows down far faster than linearly. For example, for 500 samples:

- 2 items took 9 sec
- 4 items took 20 sec
- 8 items took 1 min
- 12 items took 3 min
- 50-60 items took 1 hr
I have found that looping in Python and running a separate HMC inference for each (`skill`, `type`) pair (roughly the sketch below) is faster than passing every pair to Pyro and running HMC once over the whole model. I'm confused as to why this is the case.
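Concretely, the per-pair approach looks roughly like this, reusing the `run_inference` and `skill_given_skillSets` functions defined further down (the loop is a sketch, not my exact code):

```python
# Sketch of the per-pair approach: one independent HMC run per
# (skill, type) pair, looping in plain Python.
per_pair_samples = {}
for skill, type_dict in skills_dict.items():
    for type_, params in type_dict.items():
        name = skill + "_" + type_
        per_pair_samples[name] = run_inference(
            skill_given_skillSets,
            [name, params["alphas"], params["betas"]],
            n_samples=300,
        )
```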
## Question
Are the two nested `for i in pyro.plate(...)` loops in the `all_skills` function the most efficient way to perform inference on each independent (`skill`, `type`) pair?
My model code is below; I run it with `run_inference(all_skills, [skills_dict], n_samples=300)`.
```python
import pyro
import pyro.distributions as dist
import torch


def all_skills(skills_dict):
    """Model code to infer the overall distribution of the role requirements
    for every skill.

    Parameters
    ----------
    skills_dict : dict of dict of dict of tensor
        Dictionary containing all alpha and beta parameters for each skill.
        Structure is <skill>: <type>: <param>: <tensor of floats>.
    """
    skills = list(skills_dict.keys())
    types = list(skills_dict[skills[0]].keys())
    # Sequential plates: each iteration registers its own sample sites.
    for i in pyro.plate("skills", len(skills)):
        skill = skills[i]
        for j in pyro.plate(skill + "types", len(types)):
            type_ = types[j]
            skill_given_skillSets(
                skill + "_" + type_,
                skills_dict[skill][type_]["alphas"],
                skills_dict[skill][type_]["betas"],
            )


def skill_given_skillSets(attributeIdAndType, alphas, betas):
    """Given an array of alphas and betas for a set of Beta
    distributions for a given skillId, the model that infers
    the overall distribution of the skill.

    Parameters
    ----------
    attributeIdAndType : str
        The attributeId with the type appended.
    alphas : tensor of floats
        The alpha parameters of the corresponding Beta distributions.
    betas : tensor of floats
        The beta parameters of the corresponding Beta distributions.
    """
    obs = torch.tensor(1.0, requires_grad=False)
    # prior_a and prior_b are global prior hyperparameters defined elsewhere.
    attribute = pyro.sample(attributeIdAndType, dist.Beta(prior_a, prior_b))  # prior
    with pyro.plate(attributeIdAndType + "set_obs", len(alphas)):
        prob = pyro.deterministic(
            attributeIdAndType + "_prob",
            dist.Beta(alphas, betas).log_prob(attribute).exp(),
        )
        pyro.sample(f"{attributeIdAndType}_obs", dist.Bernoulli(probs=prob), obs=obs)


def run_inference(model, model_args, n_samples=1000, warmup=200, **model_kwargs):
    """Generic function to take a complete model function and all model
    inputs, then run MCMC over the model."""
    from pyro.infer import MCMC, NUTS

    nuts_kernel = NUTS(model)  # initialise the sampler
    mcmc = MCMC(nuts_kernel, num_samples=n_samples, warmup_steps=warmup)
    mcmc.run(*model_args, **model_kwargs)  # args feed into the model
    hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
    return hmc_samples
```
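For reference, this is roughly the kind of fully vectorized model I expected to end up with instead of the nested loops. It is only a sketch: `all_skills_vectorized`, `stacked_alphas`, and `stacked_betas` are illustrative names, and it assumes every (`skill`, `type`) pair has the same number of alpha/beta entries so they can be stacked into `(n_pairs, n_obs)` tensors:

```python
# Sketch only: assumes the alpha/beta arrays of every (skill, type) pair
# are the same length, so they can be stacked into (n_pairs, n_obs) tensors.
# prior_a and prior_b are the same global hyperparameters as above.
def all_skills_vectorized(stacked_alphas, stacked_betas):
    n_pairs, n_obs = stacked_alphas.shape
    obs = torch.ones(n_pairs, n_obs)
    with pyro.plate("pairs", n_pairs, dim=-2):
        # One latent skill value per (skill, type) pair, shape (n_pairs, 1).
        attribute = pyro.sample("attribute", dist.Beta(prior_a, prior_b))
        with pyro.plate("set_obs", n_obs, dim=-1):
            # attribute broadcasts against the (n_pairs, n_obs) Beta batch.
            prob = dist.Beta(stacked_alphas, stacked_betas).log_prob(attribute).exp()
            pyro.sample("obs", dist.Bernoulli(probs=prob), obs=obs)
```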