Sampling from the model prior in NumPyro

How do I sample from the model’s prior distribution?

Code:

import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro.infer import MCMC, NUTS, Predictive


# Simulate a toy dataset: a "true" latent vector x1 of length 27 and a
# 25 x 27 matrix A, both standard normal. Fixed PRNG keys keep the draws
# reproducible.
x1 = random.normal(random.PRNGKey(144), shape=(27,))
A = random.normal(random.PRNGKey(3), shape=(25, 27))

# Noise-free targets: x1 broadcasts across the rows of A, so
# y[i] = logsumexp_j(A[i, j] - x1[j]) and y has shape (25,).
y = logsumexp(A - x1, axis=1)


def model(A, y=None):
    """Latent-variable model: ``y ~ MVN(logsumexp(A - x, axis=1), I)``.

    A latent vector ``x`` (one standard-normal entry per column of ``A``) is
    subtracted from every row of ``A``. Each row is then reduced with
    ``logsumexp`` to give the mean of a multivariate normal with identity
    covariance.

    Args:
        A: 2-D matrix. ``A.shape[1]`` sets the length of ``x``, and
            ``A.shape[0]`` sets the length of the observation vector.
        y: Observations of length ``A.shape[0]``. Leave as ``None`` to have
            ``y_obs`` sampled instead of conditioned, e.g. when drawing from
            the prior predictive with ``Predictive``.
    """
    n = A.shape[1]
    # One independent standard-normal latent per column of A.
    with numpyro.plate("x", n):
        x = numpyro.sample("x", dist.Normal())
    # x broadcasts across the rows of A, so mu has one entry per row of A.
    mu = logsumexp(A - x, axis=1)
    # The covariance must be square with side len(mu) == A.shape[0].
    # Passing the whole shape tuple, as in jnp.eye(A.shape), raises a TypeError.
    numpyro.sample(
        "y_obs",
        dist.MultivariateNormal(loc=mu, covariance_matrix=jnp.eye(A.shape[0])),
        obs=y,
    )
        

# Predictive.__call__ expects a PRNG key as its first positional argument and
# forwards the remaining arguments to the model. Passing A first made NumPyro
# treat the (25, 27) matrix as the key, which caused the reshape TypeError.
# With y left as None, "y_obs" is sampled, so this draws from the prior
# predictive.
prior_samples = Predictive(model, num_samples=1, batch_ndims=1)(random.PRNGKey(0), A)

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-39c07926c018> in <module>
----> 1 Predictive(model, num_samples=1, batch_ndims=1)(A)

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)
    585                                             model_args=args, model_kwargs=kwargs)
    586         model = substitute(self.model, self.params)
--> 587         return _predictive(rng_key, model, posterior_samples, self._batch_shape,
    588                            return_sites=self.return_sites, parallel=self.parallel,
    589                            model_args=args, model_kwargs=kwargs)

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, parallel, model_args, model_kwargs)
    486     if num_samples > 1:
    487         rng_key = random.split(rng_key, num_samples)
--> 488     rng_key = rng_key.reshape(batch_shape + (2,))
    489     chunk_size = num_samples if parallel else 1
    490     return soft_vmap(single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size)

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
   1251           type(newshape[0]) is not Poly):
   1252     newshape = newshape[0]
-> 1253   return _reshape(a, newshape, order=order)
   1254 
   1255 

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _reshape(a, newshape, order)
   1229   computed_newshape = _compute_newshape(a, newshape)
   1230   if order == "C":
-> 1231     return lax.reshape(a, computed_newshape, None)
   1232   elif order == "F":
   1233     dims = np.arange(ndim(a))[::-1]

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in reshape(operand, new_sizes, dimensions)
    703     return operand
    704   else:
--> 705     return reshape_p.bind(
    706       operand, new_sizes=new_sizes,
    707       dimensions=None if dimensions is None or same_dims else tuple(dimensions))

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
    268     top_trace = find_top_trace(args)
    269     tracers = map(top_trace.full_raise, args)
--> 270     out = top_trace.process_primitive(self, tracers, params)
    271     return map(full_lower, out) if self.multiple_results else full_lower(out)
    272 

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
    578 
    579   def process_primitive(self, primitive, tracers, params):
--> 580     return primitive.impl(*tracers, **params)
    581 
    582   def process_call(self, primitive, f, tracers, params):

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in _reshape_impl(operand, new_sizes, dimensions)
   3464         lazy_expr = lazy.broadcast(operand._lazy_expr, new_sizes, bcast_dims)
   3465       return xla.make_device_array(aval, operand._device, lazy_expr, operand.device_buffer)
-> 3466   return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes,
   3467                              dimensions=dimensions)
   3468 

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    233 def apply_primitive(prim, *args, **params):
    234   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 235   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    236   return compiled_fun(*args)
    237 

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/interpreters/xla.py in xla_primitive_callable(prim, *arg_specs, **params)
    258     return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim", donated_invars,
    259                          *arg_specs)
--> 260   aval_out = prim.abstract_eval(*avals, **params)
    261   if not prim.multiple_results:
    262     handle_result = aval_to_result_handler(device, aval_out)

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs)
   1992     out_avals = safe_map(ConcreteArray, out_vals)
   1993   elif least_specialized is ShapedArray:
-> 1994     shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
   1995     if not prim.multiple_results:
   1996       shapes, dtypes = [shapes], [dtypes]

~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in _reshape_shape_rule(operand, new_sizes, dimensions)
   3493   if prod(np.shape(operand)) != prod(new_sizes):
   3494     msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
-> 3495     raise TypeError(msg.format(new_sizes, np.shape(operand)))
   3496   if dimensions is not None:
   3497     if set(dimensions) != set(range(np.ndim(operand))):

TypeError: reshape total size must be unchanged, got new_sizes (1, 2) for shape (25, 27).

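Hi @w0rlord, `Predictive` needs a random key as the first argument of the call, as in this tutorial. For example: `Predictive(...)(random.PRNGKey(0), A)`. In your call, `A` was passed in the key's position, which is why NumPyro tried to reshape your (25, 27) array into a key of shape (1, 2). For a single prediction, you can use `batch_ndims=0`. The result with `batch_ndims=1` will have an additional singleton dimension.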

1 Like