To plot a trace obtained via MCMC sampling we do:
import arviz as az
data = az.from_pyro(mcmc)
az.plot_trace(data, compact=True);
where mcmc is the fitted MCMC object from Pyro. Is there a way to do the plotting when inference is done with SVI instead of MCMC?
Thanks!
I think you can use Predictive with model=guide and params=svi.get_params(svi_state) to get posterior samples. Then you can plot their histogram.
Finally figured it out! ArviZ's from_pyro only works with a fitted pyro.infer.MCMC object.
However, there is a workaround:
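import arviz as az
from pyro.infer import Predictive

# model, guide (already trained with SVI) and x_test are defined earlier
predictive = Predictive(model, guide=guide, num_samples=500)
preds = predictive(x_test)
# Drop the observed site 'obs' and add a leading chain axis: ArviZ expects (chain, draw, ...)
sanitized_preds = {k: v.unsqueeze(0).detach().numpy() for k, v in preds.items() if k != 'obs'}
pyro_data = az.convert_to_inference_data(sanitized_preds)
az.plot_trace(pyro_data, compact=True);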
Then it works!
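The key step in this workaround is the reshaping in the dictionary comprehension: ArviZ reads the first two axes of each array as (chain, draw), so without unsqueeze(0) the 500 draws would be interpreted as 500 chains of one draw each. Below is a minimal NumPy sketch of just that step. The helper to_arviz_layout and the fake sample arrays are hypothetical stand-ins for the Predictive output, assumed to be already converted with .detach().numpy().

```python
import numpy as np


def to_arviz_layout(samples, exclude=("obs",)):
    """Hypothetical helper: reshape {site: array of shape (draws, ...)} into
    ArviZ's (chain, draw, ...) layout by adding a single leading chain axis."""
    out = {}
    for name, value in samples.items():
        if name in exclude:
            continue  # skip observed sites; they are not latent posterior draws
        arr = np.asarray(value)
        out[name] = arr[np.newaxis, ...]  # shape becomes (1, draws, ...)
    return out


# Fake draws standing in for the output of Predictive (500 samples per site).
rng = np.random.default_rng(0)
fake_preds = {
    "mu": rng.normal(size=500),
    "w": rng.normal(size=(500, 3)),
    "obs": rng.normal(size=(500, 20)),
}

arviz_ready = to_arviz_layout(fake_preds)
print({k: v.shape for k, v in arviz_ready.items()})
# -> {'mu': (1, 500), 'w': (1, 500, 3)}
```

The resulting dict has the same shape convention as the one the workaround passes to az.convert_to_inference_data.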
Question: there is no mention of a solution like this (at least I couldn't find one) in either the ArviZ or the Pyro documentation. Shouldn't it be documented somewhere? If so, where?
Cheers!
I think the point of diagnostics is to check whether the traces are good to use (e.g. whether the samples are autocorrelated). With SVI, posterior samples are generated independently from the guide, so there is no need to diagnose them. In my opinion, using ArviZ here is unnecessary.
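To make the independence argument concrete, here is a minimal NumPy sketch (not from the thread) comparing the lag-1 autocorrelation of independent draws, as a fitted Normal guide would produce, with that of a correlated AR(1) chain. The AR(1) chain is only an assumed stand-in for a slowly mixing MCMC trace, and lag1_autocorr is a hypothetical helper.

```python
import numpy as np


def lag1_autocorr(x):
    """Sample lag-1 autocorrelation of a 1-D sequence."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    return float(np.dot(x[:-1], x[1:]) / np.dot(x, x))


rng = np.random.default_rng(42)
n = 5000

# Independent draws, like samples taken from a fitted Normal guide.
guide_draws = rng.normal(loc=0.0, scale=1.0, size=n)

# Correlated AR(1) chain with unit stationary variance (stand-in for an MCMC trace).
rho = 0.9
chain = np.empty(n)
chain[0] = rng.normal()
for t in range(1, n):
    chain[t] = rho * chain[t - 1] + np.sqrt(1 - rho**2) * rng.normal()

print(f"guide draws lag-1 autocorrelation: {lag1_autocorr(guide_draws):.3f}")
print(f"AR(1) chain lag-1 autocorrelation: {lag1_autocorr(chain):.3f}")
```

The independent draws give an autocorrelation close to zero, while the AR(1) chain stays close to rho; trace diagnostics are designed to flag the second situation.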