Hi,
I am just getting my feet wet with PPLs and Pyro. One of the things I would like to do is compare models. I saw that ArviZ provides arviz.compare, but I am hitting a snag. My setup is roughly the following:
```python
import arviz as az
from pyro.infer import Predictive

# model1/model2, the fitted MCMC objects mcmc1/mcmc2, X1_train/X2_train
# and DRAWS are defined elsewhere (not shown).

# Get priors and posteriors for the first model
prior_predictive1 = Predictive(model1, num_samples=DRAWS)
prior_predictions1 = prior_predictive1(X=X1_train, y=None)
posterior_samples1 = mcmc1.get_samples(num_samples=DRAWS)
posterior_predictive1 = Predictive(model1, posterior_samples1)
posterior_predictions1 = posterior_predictive1(X=X1_train, y=None)

# Get priors and posteriors for the second model
prior_predictive2 = Predictive(model2, num_samples=DRAWS)
prior_predictions2 = prior_predictive2(X=X2_train, y=None)
posterior_samples2 = mcmc2.get_samples(num_samples=DRAWS)
posterior_predictive2 = Predictive(model2, posterior_samples2)
posterior_predictions2 = posterior_predictive2(X=X2_train, y=None)

# Build the InferenceData objects
inference_data1 = az.from_pyro(
    mcmc1, prior=prior_predictions1, posterior_predictive=posterior_predictions1
)
inference_data2 = az.from_pyro(
    mcmc2, prior=prior_predictions2, posterior_predictive=posterior_predictions2
)
compare_dict = {"m1": inference_data1, "m2": inference_data2}

# Compare the models
az.compare(compare_dict)
```
Building the InferenceData objects with az.from_pyro prints this warning:

```
/Projects/pyro/pyro-3.9/lib/python3.9/site-packages/arviz/data/io_pyro.py:157: UserWarning: Could not get vectorized trace, log_likelihood group will be omitted. Check your model vectorization or set log_likelihood=False
```

and az.compare then fails with:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/Projects/pyro/pyro-3.9/lib/python3.9/site-packages/arviz/stats/stats.py:248, in compare(dataset_dict, ic, method, b_samples, alpha, seed, scale, var_name)
    245 try:
    246     # Here is where the IC function is actually computed -- the rest of this
    247     # function is argument processing and return value formatting
--> 248     ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)])
    249 except Exception as e:

File ~/Projects/pyro/pyro-3.9/lib/python3.9/site-packages/arviz/stats/stats.py:657, in loo(data, pointwise, var_name, reff, scale)
    656 inference_data = convert_to_inference_data(data)
--> 657 log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
    658 pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

File ~/Projects/pyro/pyro-3.9/lib/python3.9/site-packages/arviz/stats/stats_utils.py:426, in get_log_likelihood(idata, var_name)
    425 if not hasattr(idata, "log_likelihood"):
--> 426     raise TypeError("log likelihood not found in inference data object")
    427 if var_name is None:

TypeError: log likelihood not found in inference data object
```
This looks like an ArviZ-related issue, but I thought I'd check here to see whether there are workarounds, or whether there is another preferred way to handle this.