Plotting trace from SVI

To plot a trace obtained via MCMC sampling we do:

import arviz as az

data = az.from_pyro(mcmc)
az.plot_trace(data, compact=True);

where mcmc is the fitted MCMC object from Pyro. Is there a way to do the plotting when the inference is done with SVI instead of MCMC?

Thanks!

I think you can use Predictive with model=guide and params=svi.get_params(svi_state) to get posterior samples. Then you can plot their histogram.

Finally figured it out! ArviZ’s from_pyro only works with a fitted pyro.infer.MCMC object.
However, there is a workaround:

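import arviz as az
from pyro.infer import Predictive

# Draw 500 samples using the fitted guide
predictive = Predictive(model, guide=guide, num_samples=500)
preds = predictive(x_test)
# Drop the 'obs' site, add a leading chain dimension (ArviZ expects (chain, draw, ...)), convert to NumPy
sanitized_preds = {k: v.unsqueeze(0).detach().numpy() for k, v in preds.items() if k != 'obs'}
pyro_data = az.convert_to_inference_data(sanitized_preds)
az.plot_trace(pyro_data, compact=True);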

Then, it works!

Question: I couldn’t find any mention of a solution like this in either the ArviZ or the Pyro documentation. Shouldn’t it be documented somewhere? If so, where?

Cheers!

I think the point of trace diagnostics is to check whether the samples are good to use (e.g. whether they are correlated). For SVI, posterior samples are drawn independently from the guide, so there is no need to diagnose them. Using ArviZ is unnecessary in my opinion.
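
To illustrate the difference, here is a minimal sketch that uses only the Python standard library, not Pyro. The independent Gaussian draws stand in for samples taken from a fitted guide, and the AR(1) sequence stands in for a slowly mixing MCMC chain. Both stand-ins, the helper name lag1_autocorr, and the value rho = 0.9 are assumptions made for illustration.

```python
import random
import statistics


def lag1_autocorr(xs):
    """Sample lag-1 autocorrelation of a sequence (hypothetical helper)."""
    mean = statistics.fmean(xs)
    num = sum((a - mean) * (b - mean) for a, b in zip(xs, xs[1:]))
    den = sum((a - mean) ** 2 for a in xs)
    return num / den


rng = random.Random(0)
n = 5000

# Independent draws: an assumed stand-in for samples from a fitted guide.
iid_draws = [rng.gauss(0.0, 1.0) for _ in range(n)]

# AR(1) sequence: an assumed stand-in for a correlated MCMC chain.
rho = 0.9
chain = [0.0]
for _ in range(n - 1):
    chain.append(rho * chain[-1] + rng.gauss(0.0, 1.0))

iid_rho = lag1_autocorr(iid_draws)
chain_rho = lag1_autocorr(chain)
print(f"independent draws: lag-1 autocorrelation = {iid_rho:.3f}")
print(f"AR(1) chain:       lag-1 autocorrelation = {chain_rho:.3f}")
```

The independent draws should show a lag-1 autocorrelation close to zero, while the AR(1) sequence should stay close to rho. That kind of correlation is what trace diagnostics are meant to reveal.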

Hmmm makes sense… thx!