Hi, I have the following model where one of the parameters (theta_5
) is normally distributed and constrained to be positive. I’m getting an error related to funsor while running the model. The error stack trace does not mention any particular line in the model code, and I was unable to reproduce the error with a simpler example, hence I’m posting the full model here. Note that this model converges when the constraint theta_5 > 0
is removed, albeit to a negative value of theta_5
, which in this case is unphysical. Also, I have tried to fit this constrained model in Stan and it works for the given data.
The model code.
def dst(theta, t):
return theta[..., 0] + 0.5*theta[..., 1] * (
jnp.tanh(theta[..., 4] * (t - theta[..., 2])) -
jnp.tanh(theta[..., 5] * (t - theta[..., 3]))
)
def my_model(V_obs, t, index_mapping, L, pi, theta_mean, theta_std, s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL):
with numpyro.plate("L", L):
c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))
theta_1 = numpyro.sample("theta_1", dist.Normal(loc=jnp.array(theta_mean[c, 0]), scale=jnp.array(theta_std[c, 0])))
theta_2 = numpyro.sample("theta_2", dist.Normal(loc=jnp.array(theta_mean[c, 1]), scale=jnp.array(theta_std[c, 1])))
theta_5 = numpyro.sample("theta_5",
dist.TransformedDistribution(
dist.Normal(loc=jnp.array(theta_mean[c, 2]), scale=jnp.array(theta_std[c, 4])),
transforms.ExpTransform()
))
theta_6 = numpyro.sample("theta_6", dist.Normal(loc=jnp.array(theta_mean[c, 3]), scale=jnp.array(theta_std[c, 5])))
theta_3 = numpyro.sample("theta_3", dist.Normal(loc=jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1]), scale=jnp.array(theta_std[c, 2])))
theta_4 = numpyro.sample("theta_4", dist.Normal(loc=jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1]), scale=jnp.array(theta_std[c, 3])))
theta = numpyro.deterministic("theta", jnp.stack([theta_1, theta_2, theta_3, theta_4, theta_5, theta_6], axis=-1))
with numpyro.plate("SL", SL):
v_t = dst(theta[..., index_mapping, :], t)
V = numpyro.sample("V", dist.Normal(v_t, sigma), obs=V_obs)
The data and the code to run the model
sampler = infer.MCMC(
infer.NUTS(my_model),
num_warmup=500,
num_samples=500,
num_chains=2,
progress_bar=True
)
V_obs = jnp.array(
[0.09903913, 0.11762774, 0.12609756, 0.26895392, 0.40705281,
0.5315631 , 0.6084391 , 0.5900692 , 0.56697017, 0.5216723 ,
0.5225853 , 0.2768902 , 0.20479909, 0.15589964, 0.08958418], dtype='float32'
)
t = jnp.array([ 95, 127, 135, 175, 183, 191, 215, 223, 231, 239, 247, 263, 271, 279, 303], dtype='int32')
index_mapping = jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='int32')
pi = jnp.array([0.5, 0.5], dtype='float32')
theta_mean = jnp.array(
[[0.106809 , 0.629191 , 0.0809097, 0.0688024],
[0.129963 , 0.767201 , 0.0594144, 0.0990926]]
, dtype='float32'
)
theta_std = jnp.array(
[[0.00858686, 0.048192 , 5.28219 , 7.26483 , 0.0132179 ,
0.0185748 ],
[0.00837342, 0.0493153 , 5.76666 , 7.23596 , 0.0156109 ,
0.0194363 ]]
, dtype='float32'
)
s_line_fit_params = jnp.array(
[[113.968 , 0.483563],
[107.069 , 0.606011]]
, dtype='float32'
)
h_line_fit_params = jnp.array(
[[78.3904 , 0.594259],
[57.3171 , 0.701823]]
, dtype='float32'
)
s_prior = jnp.array(
[[127.08 , 12.664 ],
[140.122 , 15.9329]], dtype='float32'
)
h_prior = jnp.array(
[[299.296 , 15.6646],
[288.097 , 10.5351]], dtype='float32'
)
sigma = 0.0436177
SL = V_obs.size
sampler.run(jrng_key, V_obs, t, index_mapping, L, pi, theta_mean, theta_std, s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL)
The error stracktrace
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [256], in <cell line: 50>()
47 sigma = 0.0436177
48 SL = V_obs.size
---> 50 sampler.run(jrng_key, V_obs, t, index_mapping, L, pi, theta_mean, theta_std, s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL)
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 17 frame]
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703 rng_key, rng_key_init_model = jnp.swapaxes(
704 vmap(random.split)(rng_key), 0, 1
705 )
--> 706 init_params = self._init_state(
707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710 raise ValueError(
711 "Valid value of `init_params` must be provided with" " `potential_fn`."
712 )
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
--> 652 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653 rng_key,
654 self._model,
655 dynamic_args=True,
656 init_strategy=self._init_strategy,
657 model_args=model_args,
658 model_kwargs=model_kwargs,
659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
662 self._init_fn, self._sample_fn = hmc(
663 potential_fn_gen=potential_fn,
664 kinetic_fn=self._kinetic_fn,
665 algo=self._algo,
666 )
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:653, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
651 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
652 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 653 (init_params, pe, grad), is_valid = find_valid_initial_params(
654 rng_key,
655 substitute(
656 model,
657 data={
658 k: site["value"]
659 for k, site in model_trace.items()
660 if site["type"] in ["plate"]
661 },
662 ),
663 init_strategy=init_strategy,
664 enum=has_enumerate_support,
665 model_args=model_args,
666 model_kwargs=model_kwargs,
667 prototype_params=prototype_params,
668 forward_mode_differentiation=forward_mode_differentiation,
669 validate_grad=validate_grad,
670 )
672 if not_jax_tracer(is_valid):
673 if device_get(~jnp.all(is_valid)):
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:394, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
392 # Handle possible vectorization
393 if rng_key.ndim == 1:
--> 394 (init_params, pe, z_grad), is_valid = _find_valid_params(
395 rng_key, exit_early=True
396 )
397 else:
398 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:387, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
383 return (init_params, pe, z_grad), is_valid
385 # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
386 # even if the init_state is a valid result
--> 387 _, _, (init_params, pe, z_grad), is_valid = while_loop(
388 cond_fn, body_fn, init_state
389 )
390 return (init_params, pe, z_grad), is_valid
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/util.py:131, in while_loop(cond_fun, body_fun, init_val)
129 return val
130 else:
--> 131 return lax.while_loop(cond_fun, body_fun, init_val)
[... skipping hidden 11 frame]
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:364, in find_valid_initial_params.<locals>.body_fn(state)
362 z_grad = jacfwd(potential_fn)(params)
363 else:
--> 364 pe, z_grad = value_and_grad(potential_fn)(params)
365 z_grad_flat = ravel_pytree(z_grad)[0]
366 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 8 frame]
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:246, in potential_energy(model, model_args, model_kwargs, params, enum)
242 substituted_model = substitute(
243 model, substitute_fn=partial(_unconstrain_reparam, params)
244 )
245 # no param is needed for log_density computation because we already substitute
--> 246 log_joint, model_trace = log_density_(
247 substituted_model, model_args, model_kwargs, {}
248 )
249 return -log_joint
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py:274, in log_density(model, model_args, model_kwargs, params)
253 def log_density(model, model_args, model_kwargs, params):
254 """
255 Similar to :func:`numpyro.infer.util.log_density` but works for models
256 with discrete latent variables. Internally, this uses :mod:`funsor`
(...)
272 :return: log of joint density and a corresponding model trace
273 """
--> 274 result, model_trace, _ = _enum_log_density(
275 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
276 )
277 return result.data, model_trace
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py:181, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
178 log_prob = scale * log_prob
180 dim_to_name = site["infer"]["dim_to_name"]
--> 181 log_prob_factor = funsor.to_funsor(
182 log_prob, output=funsor.Real, dim_to_name=dim_to_name
183 )
185 time_dim = None
186 for dim, name in dim_to_name.items():
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/functools.py:888, in singledispatch.<locals>.wrapper(*args, **kw)
884 if not args:
885 raise TypeError(f'{funcname} requires at least '
886 '1 positional argument')
--> 888 return dispatch(args[0].__class__)(*args, **kw)
File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/funsor/tensor.py:491, in tensor_to_funsor(x, output, dim_to_name)
489 result = Tensor(x, dtype=output.dtype)
490 if result.output != output:
--> 491 raise ValueError(
492 "Invalid shape: expected {}, actual {}".format(
493 output.shape, result.output.shape
494 )
495 )
496 return result
497 else:
ValueError: Invalid shape: expected (), actual (1,)
As per the suggestion in this question, I also tried to run the sampler with numpyro.validation_enabled()
, but it didn’t give any extra information other than the stack trace.
Any help is appreciated.