Using pyro.deterministic to create trace of variable

I am fairly new to Pyro and need your help with the following:

I am performing an HMC run and would like to keep track of the trace of a variable that does not affect the model's likelihood. The sole reason for doing so is to have easy access to this variable's trace after the run. In my simple case (based on the examples at pyro.ai/examples), this is just the off-diagonal term of the Cholesky factor, which I pass to “x”. When I use pyro.deterministic, I cannot see “x” in the collected posterior samples. Do you have any solutions?

def model(y):
    d=y.shape[1]
    N=y.shape[0]
    theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d)))
    eta = torch.ones(1)
    L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(d, eta))
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)
    mu = torch.zeros(d)

    x = pyro.deterministic("x", L_omega[0, 1])

    with pyro.plate("observations", N):
        obs = pyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs

Thank you in advance for your help.

See the docs for print_summary: you can pass in exclude_deterministic=False.

Thanks for your quick response. I get the following error:
AttributeError: 'MCMC' object has no attribute 'print_summary'

Also, would print_summary(exclude_deterministic=False) lead to the posterior containing the trace of the deterministic variable?

You can find my script below:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS

def model(y):
    d=y.shape[1]
    N=y.shape[0]
    theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d)))
    eta = torch.ones(1)
    L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(d, eta))
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)
    mu = torch.zeros(d)
    with pyro.plate("observations", N):
        obs = pyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs

if __name__ == "__main__":
    y = torch.randn(size=(100, 2))
    nuts_kernel = NUTS(model=model)
    mcmc = MCMC(kernel=nuts_kernel, num_samples=100, num_chains=1, warmup_steps=100)
    mcmc.run(y)
    mcmc.print_summary(exclude_deterministic=False)
    posterior=mcmc.get_samples(100)

Oh sorry, I thought you were using NumPyro. I'd recommend trying this in NumPyro; it'll be much faster.

I am using Windows and have trouble installing JAX, so NumPyro is not an option. Is there no way to achieve this with Pyro?

@fehiepsi do you know what the behavior on the Pyro side is? Are the deterministic sites filtered out somewhere?

In Pyro, we need to use Predictive as in this example to get the values of any site, including deterministic ones (through the return_sites argument).