Amav
December 5, 2020, 7:47pm
I am fairly new to Pyro and need your help with the following:
I am running HMC and would like to keep track of the trace of a variable that does not affect the model's likelihood. The only reason for doing so is to have easy access to this variable's trace after the run. In my simple case (based on pyro.ai/examples), this is just the off-diagonal term of the Cholesky factor, which I pass to `x`. When I use `pyro.deterministic`, I cannot see `x` in the collected posterior samples. Do you have any solutions?
```python
def model(y):
    d = y.shape[1]
    N = y.shape[0]
    theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d)))
    eta = torch.ones(1)
    L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(d, eta))
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)
    mu = torch.zeros(d)
    x = pyro.deterministic("x", L_omega[0, 1])
    with pyro.plate("observations", N):
        obs = pyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs
```
Thank you in advance for your help.
See the docs for `print_summary`. You can pass in `exclude_deterministic=False`.
Amav
December 6, 2020, 11:00am
> martinjankowiak: `exclude_deterministic`
Thanks for your quick response. I get the following error:
`AttributeError: 'MCMC' object has no attribute 'print_summary'`
Also, would `print_summary(exclude_deterministic=False)` make the collected posterior contain the trace of the deterministic variable?
You can find my script below:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS


def model(y):
    d = y.shape[1]
    N = y.shape[0]
    theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d)))
    eta = torch.ones(1)
    L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(d, eta))
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)
    mu = torch.zeros(d)
    # Deterministic site to track (as in the first snippet)
    x = pyro.deterministic("x", L_omega[0, 1])
    with pyro.plate("observations", N):
        obs = pyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs


if __name__ == "__main__":
    y = torch.randn(size=(100, 2))
    nuts_kernel = NUTS(model=model)
    mcmc = MCMC(kernel=nuts_kernel, num_samples=100, num_chains=1, warmup_steps=100)
    mcmc.run(y)
    # Raises the AttributeError above: print_summary is the NumPyro name;
    # Pyro's MCMC provides mcmc.summary() instead.
    mcmc.print_summary(exclude_deterministic=False)
    posterior = mcmc.get_samples(100)
```
Oh, sorry, I thought you were using NumPyro. I'd recommend trying this in NumPyro; it'll be much faster.
Amav
December 6, 2020, 8:30pm
I am using Windows and have trouble installing JAX, so NumPyro is not an option. Is there no way to achieve this using Pyro?
@fehiepsi, do you know what the behavior is on the Pyro side? Are the deterministic sites filtered out somewhere?
In Pyro, we need to use `Predictive` as in this example to get the values of any sites (through the `return_sites` argument).