Sampling from Deterministic sites using Predictive

Hi there,

I have a model with deterministic transformations of random variables that I would like to sample from. I’ve followed the advice from here and here, but I get a type error, even in a simple example. Thanks in advance!

A simple example:

def model(y,n=20):
    """
    Define the likelihood and the data input, state the priors.

    Args:
        y: observed counts; ``len(y)`` sets the size of the "data" plate and
            the values are passed as ``obs`` to the "obs" site.
        n: first argument of the Binomial likelihood (the number of trials).
    """
    p = pyro.sample("p", pyro.distributions.Beta(2,2)) #prior
    # Record the transformed value p - 1 under the site name 'test'.
    test = pyro.deterministic('test', p-1)
    with pyro.plate("data",len(y)):
        #likelihood and data statement
        pyro.sample("obs",pyro.distributions.Binomial(n,p), obs=y)

from pyro.infer import MCMC, NUTS, Predictive
import torch
import numpy as np
import pandas as pd

p = 0.7  # success probability used to simulate the data

# Simulate the data: 100 binomial draws with 20 trials each
y = np.random.binomial(20,p, size=100 )
y_tensor = torch.tensor(y, dtype=torch.float) 
nuts_kernel = NUTS(model)

# Initialise the MCMC object with the NUTS kernel (1000 samples, 200 warmup steps).
# NOTE(review): the mcmc.run(...) call is not shown in this snippet; presumably its
# arguments are the ones that feed into the model arguments.
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)

# hmc_samples is not defined in this snippet; it is built from mcmc.get_samples()
# with the dict comprehension shown at the end of the thread.
predict = Predictive(model,hmc_samples, num_samples=100)(y_tensor).get_samples()

The above gives me this in the traceback:

AttributeError                            Traceback (most recent call last)
<ipython-input-14-f3fe4ee26164> in <module>
      1 from pyro.infer import Predictive
----> 3 predict = Predictive(model,hmc_samples, num_samples=100)(y_tensor).get_samples()
      4 predict

/anaconda3/lib/python3.6/site-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/anaconda3/lib/python3.6/site-packages/pyro/infer/ in forward(self, *args, **kwargs)
    203                                             parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    204         return _predictive(self.model, posterior_samples, self.num_samples, return_sites=return_sites,
--> 205                            parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    207     def get_samples(self, *args, **kwargs):

/anaconda3/lib/python3.6/site-packages/pyro/infer/ in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
     90     if not parallel:
     91         return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples,
---> 92                                       return_site_shapes, return_trace=False)
     94     trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\

/anaconda3/lib/python3.6/site-packages/pyro/infer/ in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
     36     samples = [{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)]
     37     for i in range(num_samples):
---> 38         trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(*model_args, **model_kwargs)
     39         if return_trace:
     40             collected.append(trace)

/anaconda3/lib/python3.6/site-packages/pyro/poutine/ in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()

/anaconda3/lib/python3.6/site-packages/pyro/poutine/ in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:
    167                 exc_type, exc_value, traceback = sys.exc_info()

/anaconda3/lib/python3.6/site-packages/pyro/poutine/ in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

<ipython-input-5-3e7c54a24bea> in model(y, n)
      6     """
      7     p = pyro.sample("p", pyro.distributions.Beta(2,2)) #prior
----> 8     test = pyro.deterministic('test', p-1)
      9     with pyro.plate("data",len(y)):
     10         #likelihood and data statement

/anaconda3/lib/python3.6/site-packages/pyro/ in deterministic(name, value, event_dim)
    146     """
    147     event_dim = value.ndim if event_dim is None else event_dim
--> 148     return sample(name, dist.Delta(value, event_dim=event_dim).mask(False),
    149                   obs=value, infer={"_deterministic": True})

/anaconda3/lib/python3.6/site-packages/pyro/distributions/ in __call__(cls, *args, **kwargs)
     16             if result is not None:
     17                 return result
---> 18         return super().__call__(*args, **kwargs)
     20     @property

/anaconda3/lib/python3.6/site-packages/pyro/distributions/ in __init__(self, v, log_density, event_dim, validate_args)
     31     def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None):
---> 32         if event_dim > v.dim():
     33             raise ValueError('Expected event_dim <= v.dim(), actual {} vs {}'.format(event_dim, v.dim()))
     34         batch_dim = v.dim() - event_dim

AttributeError: 'numpy.float64' object has no attribute 'dim'

Actually, it seems if I remove the posterior samples hmc_samples from the input, I can get my deterministic site:

predict = Predictive(model, num_samples=100)(y_tensor)

gives me the deterministic site!

I would still like to know why Predictive fails when I provide the posterior samples, as described in the links I gave.

The error says 'numpy.float64' object has no attribute 'dim'. So you probably used a NumPy ndarray somewhere. What is your hmc_samples?

1 Like

I had grabbed them using this code snippet. I guess I shouldn’t have used .numpy()?

The docs for Predictive only mention a dict for the posterior_samples argument; I incorrectly assumed its values didn’t need to be tensors.

# Collect the posterior draws from the MCMC run. The trailing .numpy() converts every
# tensor to a NumPy array, so the dict handed to Predictive holds no torch tensors --
# consistent with the "'numpy.float64' object has no attribute 'dim'" error in the traceback.
hmc_samples = { k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items() }