How to sample from the model’s prior distribution?
Code :
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro.infer import MCMC, NUTS, Predictive
x1 = random.normal(random.PRNGKey(144), [27])
A = random.normal(random.PRNGKey(3), [25, 27])
y = logsumexp((A - x1), axis=1)
def model(A, y=None):
n = A.shape[1]
with numpyro.plate("x", n):
x = numpyro.sample("x", dist.Normal())
mu = logsumexp(A - x, axis=1)
y_obs = numpyro.sample("y_obs", dist.MultivariateNormal(loc=mu, covariance_matrix=1*jnp.eye(A.shape)), obs=y)
Predictive(model, num_samples=1, batch_ndims=1)(A)
Error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-4-39c07926c018> in <module>
----> 1 Predictive(model, num_samples=1, batch_ndims=1)(A)
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)
585 model_args=args, model_kwargs=kwargs)
586 model = substitute(self.model, self.params)
--> 587 return _predictive(rng_key, model, posterior_samples, self._batch_shape,
588 return_sites=self.return_sites, parallel=self.parallel,
589 model_args=args, model_kwargs=kwargs)
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, parallel, model_args, model_kwargs)
486 if num_samples > 1:
487 rng_key = random.split(rng_key, num_samples)
--> 488 rng_key = rng_key.reshape(batch_shape + (2,))
489 chunk_size = num_samples if parallel else 1
490 return soft_vmap(single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size)
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
1251 type(newshape[0]) is not Poly):
1252 newshape = newshape[0]
-> 1253 return _reshape(a, newshape, order=order)
1254
1255
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _reshape(a, newshape, order)
1229 computed_newshape = _compute_newshape(a, newshape)
1230 if order == "C":
-> 1231 return lax.reshape(a, computed_newshape, None)
1232 elif order == "F":
1233 dims = np.arange(ndim(a))[::-1]
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in reshape(operand, new_sizes, dimensions)
703 return operand
704 else:
--> 705 return reshape_p.bind(
706 operand, new_sizes=new_sizes,
707 dimensions=None if dimensions is None or same_dims else tuple(dimensions))
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **params)
268 top_trace = find_top_trace(args)
269 tracers = map(top_trace.full_raise, args)
--> 270 out = top_trace.process_primitive(self, tracers, params)
271 return map(full_lower, out) if self.multiple_results else full_lower(out)
272
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
578
579 def process_primitive(self, primitive, tracers, params):
--> 580 return primitive.impl(*tracers, **params)
581
582 def process_call(self, primitive, f, tracers, params):
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in _reshape_impl(operand, new_sizes, dimensions)
3464 lazy_expr = lazy.broadcast(operand._lazy_expr, new_sizes, bcast_dims)
3465 return xla.make_device_array(aval, operand._device, lazy_expr, operand.device_buffer)
-> 3466 return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes,
3467 dimensions=dimensions)
3468
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
233 def apply_primitive(prim, *args, **params):
234 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 235 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
236 return compiled_fun(*args)
237
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/interpreters/xla.py in xla_primitive_callable(prim, *arg_specs, **params)
258 return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim", donated_invars,
259 *arg_specs)
--> 260 aval_out = prim.abstract_eval(*avals, **params)
261 if not prim.multiple_results:
262 handle_result = aval_to_result_handler(device, aval_out)
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs)
1992 out_avals = safe_map(ConcreteArray, out_vals)
1993 elif least_specialized is ShapedArray:
-> 1994 shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
1995 if not prim.multiple_results:
1996 shapes, dtypes = [shapes], [dtypes]
~/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/lax/lax.py in _reshape_shape_rule(operand, new_sizes, dimensions)
3493 if prod(np.shape(operand)) != prod(new_sizes):
3494 msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
-> 3495 raise TypeError(msg.format(new_sizes, np.shape(operand)))
3496 if dimensions is not None:
3497 if set(dimensions) != set(range(np.ndim(operand))):
TypeError: reshape total size must be unchanged, got new_sizes (1, 2) for shape (25, 27).