# NumPyro Funsor error when using a positive constraint on a normally distributed variable

Hi, I have the following model where one of the parameters (`theta_5`) is normally distributed and constrained to be positive. I’m getting an error related to funsor while running the model. The error stack trace does not mention any particular line in the model code, and I was unable to reproduce the error with a simpler example, hence I’m posting the full model here. Note that this model converges when the constraint `theta_5 > 0` is removed, albeit to a negative value of `theta_5`, which in this case is unphysical. Also, I have tried to fit this constrained model in Stan and it works for the given data.

The model code.

``````def dst(theta, t):
return theta[..., 0] + 0.5*theta[..., 1] * (
jnp.tanh(theta[..., 4] * (t - theta[..., 2])) -
jnp.tanh(theta[..., 5] * (t - theta[..., 3]))
)

def my_model(V_obs, t, index_mapping, L, pi, theta_mean, theta_std, s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL):
    """Class-mixture model for the observed curve ``V_obs``.

    Inside the ``"L"`` plate a discrete class ``c ~ Categorical(pi)`` is drawn
    (marked for parallel enumeration) and used to index class-specific prior
    parameters for the continuous latents ``s``, ``h`` and ``theta_1`` ...
    ``theta_6``. The thetas are stacked into ``theta`` (a deterministic site)
    and, inside the ``"SL"`` plate, ``dst`` evaluates the selected theta rows
    at times ``t`` to give the mean of ``V ~ Normal(v_t, sigma)``, observed as
    ``V_obs``.

    Index conventions used below: ``theta_mean`` has columns for theta_1,
    theta_2, theta_5 and theta_6 (the theta_3/theta_4 means come from the line
    fits on ``s`` and ``h``), while ``theta_std`` has one column per theta, in
    order theta_1 ... theta_6.

    NOTE(review): reported to fail at initialisation with
    ``ValueError: Invalid shape: expected (), actual (1,)`` when ``L == 1``
    (pyro-ppl/numpyro#1448); ``L > 1`` works.
    """

    with numpyro.plate("L", L):
        # Discrete class per plate entry, marked for parallel enumeration.
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})

        # Class-specific prior parameters are selected by indexing with c.
        s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
        h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))

        theta_1 = numpyro.sample("theta_1", dist.Normal(loc=jnp.array(theta_mean[c, 0]), scale=jnp.array(theta_std[c, 0])))
        theta_2 = numpyro.sample("theta_2", dist.Normal(loc=jnp.array(theta_mean[c, 1]), scale=jnp.array(theta_std[c, 1])))

        # theta_5 = exp(Normal(...)), i.e. log-normal and therefore positive.
        # The Normal's loc/scale describe log(theta_5), not theta_5 itself.
        theta_5 = numpyro.sample("theta_5",
            dist.TransformedDistribution(
                dist.Normal(loc=jnp.array(theta_mean[c, 2]), scale=jnp.array(theta_std[c, 4])),
                transforms.ExpTransform()
            ))

        theta_6 = numpyro.sample("theta_6", dist.Normal(loc=jnp.array(theta_mean[c, 3]), scale=jnp.array(theta_std[c, 5])))

        # The theta_3 / theta_4 means are linear in s / h, with class-specific
        # intercept (column 0) and slope (column 1).
        theta_3 = numpyro.sample("theta_3", dist.Normal(loc=jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1]), scale=jnp.array(theta_std[c, 2])))
        theta_4 = numpyro.sample("theta_4", dist.Normal(loc=jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1]), scale=jnp.array(theta_std[c, 3])))

        # Last axis of theta holds theta_1 ... theta_6, matching the indices used by dst.
        theta = numpyro.deterministic("theta", jnp.stack([theta_1, theta_2, theta_3, theta_4, theta_5, theta_6], axis=-1))

    with numpyro.plate("SL", SL):
        # index_mapping selects a row along theta's plate axis for each
        # observation (presumably one entry per element of V_obs — confirm).
        v_t = dst(theta[..., index_mapping, :], t)
        V = numpyro.sample("V", dist.Normal(v_t, sigma), obs=V_obs)
``````

The data and the code to run the model

``````sampler = infer.MCMC(
infer.NUTS(my_model),
num_warmup=500,
num_samples=500,
num_chains=2,
progress_bar=True
)

V_obs = jnp.array(
[0.09903913, 0.11762774, 0.12609756, 0.26895392, 0.40705281,
0.5315631 , 0.6084391 , 0.5900692 , 0.56697017, 0.5216723 ,
0.5225853 , 0.2768902 , 0.20479909, 0.15589964, 0.08958418], dtype='float32'
)
t = jnp.array([ 95, 127, 135, 175, 183, 191, 215, 223, 231, 239, 247, 263, 271, 279, 303], dtype='int32')
index_mapping = jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='int32')
pi = jnp.array([0.5, 0.5], dtype='float32')
theta_mean = jnp.array(
[[0.106809 , 0.629191 , 0.0809097, 0.0688024],
[0.129963 , 0.767201 , 0.0594144, 0.0990926]]
, dtype='float32'
)
theta_std = jnp.array(
[[0.00858686, 0.048192  , 5.28219   , 7.26483   , 0.0132179 ,
0.0185748 ],
[0.00837342, 0.0493153 , 5.76666   , 7.23596   , 0.0156109 ,
0.0194363 ]]
, dtype='float32'
)
s_line_fit_params = jnp.array(
[[113.968   ,   0.483563],
[107.069   ,   0.606011]]
, dtype='float32'
)
h_line_fit_params = jnp.array(
[[78.3904  ,  0.594259],
[57.3171  ,  0.701823]]
, dtype='float32'
)
s_prior = jnp.array(
[[127.08  ,  12.664 ],
[140.122 ,  15.9329]], dtype='float32'
)
h_prior = jnp.array(
[[299.296 ,  15.6646],
[288.097 ,  10.5351]], dtype='float32'
)
sigma = 0.0436177
SL = V_obs.size

sampler.run(jrng_key, V_obs, t, index_mapping, L, pi, theta_mean, theta_std, s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL)
``````

The error stack trace:

``````
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [256], in <cell line: 50>()
47 sigma = 0.0436177
48 SL = V_obs.size
---> 50 sampler.run(jrng_key, V_obs, t, index_mapping, L, pi, theta_mean, theta_std, s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL)

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597     states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599     states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601     assert self.chain_method == "vectorized"

[... skipping hidden 17 frame]

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381     init_state = self.sampler.init(
382         rng_key,
383         self.num_warmup,
384         init_params,
385         model_args=args,
386         model_kwargs=kwargs,
387     )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390     lambda x: self.sampler.get_diagnostics_str(x[0])
391     if rng_key.ndim == 1
392     else ""
393 )  # noqa: E731

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
701 # vectorized
702 else:
703     rng_key, rng_key_init_model = jnp.swapaxes(
704         vmap(random.split)(rng_key), 0, 1
705     )
--> 706 init_params = self._init_state(
707     rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
710     raise ValueError(
711         "Valid value of `init_params` must be provided with" " `potential_fn`."
712     )

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651     if self._model is not None:
--> 652         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
653             rng_key,
654             self._model,
655             dynamic_args=True,
656             init_strategy=self._init_strategy,
657             model_args=model_args,
658             model_kwargs=model_kwargs,
659             forward_mode_differentiation=self._forward_mode_differentiation,
660         )
661         if self._init_fn is None:
662             self._init_fn, self._sample_fn = hmc(
663                 potential_fn_gen=potential_fn,
664                 kinetic_fn=self._kinetic_fn,
665                 algo=self._algo,
666             )

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:653, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
651     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
652 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 653 (init_params, pe, grad), is_valid = find_valid_initial_params(
654     rng_key,
655     substitute(
656         model,
657         data={
658             k: site["value"]
659             for k, site in model_trace.items()
660             if site["type"] in ["plate"]
661         },
662     ),
663     init_strategy=init_strategy,
664     enum=has_enumerate_support,
665     model_args=model_args,
666     model_kwargs=model_kwargs,
667     prototype_params=prototype_params,
668     forward_mode_differentiation=forward_mode_differentiation,
670 )
672 if not_jax_tracer(is_valid):
673     if device_get(~jnp.all(is_valid)):

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:394, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
392 # Handle possible vectorization
393 if rng_key.ndim == 1:
--> 394     (init_params, pe, z_grad), is_valid = _find_valid_params(
395         rng_key, exit_early=True
396     )
397 else:
398     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:387, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
383             return (init_params, pe, z_grad), is_valid
385 # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
386 # even if the init_state is a valid result
--> 387 _, _, (init_params, pe, z_grad), is_valid = while_loop(
388     cond_fn, body_fn, init_state
389 )
390 return (init_params, pe, z_grad), is_valid

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/util.py:131, in while_loop(cond_fun, body_fun, init_val)
129     return val
130 else:
--> 131     return lax.while_loop(cond_fun, body_fun, init_val)

[... skipping hidden 11 frame]

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:364, in find_valid_initial_params.<locals>.body_fn(state)
363 else:
366 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

[... skipping hidden 8 frame]

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/infer/util.py:246, in potential_energy(model, model_args, model_kwargs, params, enum)
242 substituted_model = substitute(
243     model, substitute_fn=partial(_unconstrain_reparam, params)
244 )
245 # no param is needed for log_density computation because we already substitute
--> 246 log_joint, model_trace = log_density_(
247     substituted_model, model_args, model_kwargs, {}
248 )
249 return -log_joint

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py:274, in log_density(model, model_args, model_kwargs, params)
253 def log_density(model, model_args, model_kwargs, params):
254     """
255     Similar to :func:`numpyro.infer.util.log_density` but works for models
256     with discrete latent variables. Internally, this uses :mod:`funsor`
(...)
272     :return: log of joint density and a corresponding model trace
273     """
--> 274     result, model_trace, _ = _enum_log_density(
276     )
277     return result.data, model_trace

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py:181, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
178     log_prob = scale * log_prob
180 dim_to_name = site["infer"]["dim_to_name"]
--> 181 log_prob_factor = funsor.to_funsor(
182     log_prob, output=funsor.Real, dim_to_name=dim_to_name
183 )
185 time_dim = None
186 for dim, name in dim_to_name.items():

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/functools.py:888, in singledispatch.<locals>.wrapper(*args, **kw)
884 if not args:
885     raise TypeError(f'{funcname} requires at least '
886                     '1 positional argument')
--> 888 return dispatch(args[0].__class__)(*args, **kw)

File ~/cibo/testing/numpyro/numpyro_env/lib/python3.9/site-packages/funsor/tensor.py:491, in tensor_to_funsor(x, output, dim_to_name)
489     result = Tensor(x, dtype=output.dtype)
490     if result.output != output:
--> 491         raise ValueError(
492             "Invalid shape: expected {}, actual {}".format(
493                 output.shape, result.output.shape
494             )
495         )
496     return result
497 else:

ValueError: Invalid shape: expected (), actual (1,)
``````

As per the suggestion in this question, I also tried to run the sampler with `numpyro.validation_enabled()`, but it didn’t give any extra information other than the stack trace.

Any help is appreciated.

Hi @pankajb64 Could you simplify the model to just include `c` and `theta_5` (i.e., removing all variables not related to the error, removing `dst`…)? One possible solution is to draw theta_5_raw from the normal distribution and then set `theta_5 = jnp.exp(theta_5_raw)`.

Just curious, does Stan support enumeration?

Thanks. I can confirm that this (simplified) model doesn’t work

``````
def my_model(L, pi, theta_mean, theta_std):
    """Minimal reproduction: an enumerated class ``c`` plus an exp-transformed Normal.

    ``theta_5`` is ``exp`` of ``Normal(theta_mean[c, 2], theta_std[c, 4])``, i.e.
    a log-normal variable whose location/scale refer to ``log(theta_5)``.
    """
    with numpyro.plate("L", L):
        c = numpyro.sample(
            "c",
            dist.Categorical(jnp.array(pi)),
            infer={'enumerate': 'parallel'},
        )

        log_scale_base = dist.Normal(
            loc=jnp.array(theta_mean[c, 2]),
            scale=jnp.array(theta_std[c, 4]),
        )
        positive = dist.TransformedDistribution(log_scale_base, transforms.ExpTransform())
        theta_5 = numpyro.sample("theta_5", positive)
``````

but doing the `jnp.exp` change works

``````
def crop_inference_model(L, pi, theta_mean, theta_std):
    """Workaround: sample ``theta_5`` on the log scale, then exponentiate.

    ``theta_5_raw ~ Normal(theta_mean[c, 2], theta_std[c, 4])`` is the latent
    site; the positive value ``theta_5 = exp(theta_5_raw)`` is recorded as a
    deterministic site.
    """
    with numpyro.plate("L", L):
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)),
                           infer={'enumerate': 'parallel'})

        loc = jnp.array(theta_mean[c, 2])
        scale = jnp.array(theta_std[c, 4])
        theta_5_raw = numpyro.sample("theta_5_raw", dist.Normal(loc=loc, scale=scale))
        theta_5 = numpyro.deterministic("theta_5", jnp.exp(theta_5_raw))
``````

I don’t think Stan supports enumeration out of the box; I wrote code to marginalize `c` myself (using `log-sum-exp`).

I also want to point out that the same error occurs when using `OrderedTransform`. For instance, the following code does not work:

``````def my_model(L, pi, s_line_fit_params, h_line_fit_params, s_prior, h_prior):

with numpyro.plate("L", L):
c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})

s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))

gamma_1 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
gamma_2 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])

theta_mean = jnp.stack([gamma_1, gamma_2], axis=-1)
theta_std = jnp.stack([sigma_gamma[c, 2], sigma_gamma[c, 3]], axis=-1)

theta_ordered = numpyro.sample("theta_34",
dist.TransformedDistribution(
dist.Normal(loc=jnp.array(theta_mean), scale=jnp.array(theta_std)),
transforms.OrderedTransform()
))

``````

But the code below (which I think does the equivalent of what `OrderedTransform` does under the hood) works

``````def my_model(L, pi, s_line_fit_params, h_line_fit_params, s_prior, h_prior):

with numpyro.plate("L", L):
c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})

s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))

gamma_1 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
gamma_2 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])

theta_mean = jnp.stack([gamma_1, gamma_2], axis=-1)
theta_std = jnp.stack([sigma_gamma[c, 2], sigma_gamma[c, 3]], axis=-1)

theta_1 = numpyro.sample("theta_1", dist.Normal(loc=theta_mean[..., 0], scale=theta_std[..., 0]))

theta_2_raw = numpyro.sample("theta_2_raw", dist.Normal(loc=theta_mean[..., 1], scale=theta_std[..., 1]))
theta_2 = numpyro.deterministic("theta_2", theta_1 + jnp.exp(theta_2_raw))

``````

Although the latter code does not give an error and converges, it produces values that are very unexpected, even after giving it good initial values (for the variables `s` and `h`). This model has been working fine in Stan (with the ordered constraint and the positive constraint). I’m hoping to get this to work in numpyro to get speed benefits, but I’m struggling. Any help is appreciated.

It looks like we don’t support transformed distributions with non-trivial event dimensions yet (see this TODO); please open a feature request for it.

I created issue #597 in funsor. I also tried using `TruncatedNormal` with `low=0` instead of `TransformedDistribution` with `ExpTransform`, but I get the same error: `ValueError: Invalid shape: expected (), actual (1,)`. I have added that info to the feature request as well.

Can you suggest a workaround in the meantime? As I mentioned, the sampler converges but it doesn’t produce the right values. I don’t think what I’m doing above is right.

For instance, in the model in this comment, I am setting `theta_5_raw ~ normal(mu, std)` and doing `theta_5 = jnp.exp(theta_5_raw)`, whereas I actually want `theta_5` to be normally distributed with the same `mu` and `std` parameters.

The same is the case with the ordered constraint for the model in this comment.

Maybe I need to use `numpyro.factor()` here to update the log pdf, similar to incrementing the `target` variable in Stan?

The transformed distribution with exp transform should have the same meaning as drawing theta_raw and then transforming. I’m not sure if factor works (probably yes).

If what you want is a TruncatedNormal distribution (rather than a LogNormal distribution), you can define it and register the class like other distributions. I’ll take a look over the weekend.

``````# you'll need to define your own TruncatedNormal with `loc`, `scale` arguments
# where you can reuse `dist.TruncatedNormal(loc, scale).sample` and `log_prob` methods.
funsor.distributions.make_dist(TruncatedNormal, param_names=("loc", "scale"))
``````

I implemented the following, based on how Stan Manual describes they implement transformation of a lower-bounded variable. When I use it as a distribution with `funsor.distribution.make_dist`, I still get the error: `ValueError: Invalid shape: expected (), actual (1,)`

``````class PositiveTransformedNormalDist(Distribution):
support = constraints.positive
def __init__(self, loc, scale):

self.loc = loc
self.scale = scale
self._dist = dist.Normal(loc, scale)
super().__init__(batch_shape=batch_shape)

#From implementation of LeftTruncatedDistribution in
#https://num.pyro.ai/en/stable/_modules/numpyro/distributions/truncated.html#LeftTruncatedDistribution
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
dtype = jnp.result_type(float)
finfo = jnp.finfo(dtype)
minval = finfo.tiny
u = jax.random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)

return jnp.exp(self._dist.icdf(u))

def log_prob(self, x):
ex = jnp.exp(x)
return self._dist.log_prob(ex) + x
``````

This implementation could be wrong, especially the `sample` function: I derived the ICDF by hand, and I could have made a mistake. But the point is, the error persists.

I looked into this a little bit more, and I don’t think using the transformation I’m using does what I want it to.

I want to declare a variable to be sampled from a normal distribution with the constraint that it is lower bounded by 0. So something like this in Stan

``````
data {
  real mu;
  real sigma;
}
parameters {
  // Declared with a lower bound of 0, so theta is restricted to positive values.
  real<lower=0> theta;
}
model {
  // Normal(mu, sigma) density restricted to the declared support theta > 0.
  theta ~ normal(mu, sigma);
}
``````

Can you advise how I can implement this in numpyro? I can confirm that the following code IS NOT equivalent to the above Stan code; the prior predictive distribution looks different:

``````def model(mu, sigma):
theta_raw = numpyro.sample("theta_raw", dist.Normal(mu, sigma))
theta = jnp.exp(theta_raw)
``````

And from what you said in this thread, the above numpyro model is equivalent to the below numpyro model, so this must also not be what I want.

``````def model(mu, sigma):
theta = numpyro.sample("theta", dist.TransformedDistribution(dist.Normal(mu, sigma), transforms.ExpTransform()))
``````

That Stan statement corresponds to an improper truncated distribution. To complete my previous suggestion, here is the corresponding code

``````import funsor
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

class TruncatedNormal(dist.Normal):
support = dist.constraints.positive
def sample(self, key, sample_shape=()):
return dist.TruncatedNormal(self.loc, self.scale, low=0.).sample(key, sample_shape=sample_shape)
def log_prob(self, value):
return dist.TruncatedNormal(self.loc, self.scale, low=0.).log_prob(value)

funsor.distribution.make_dist(TruncatedNormal, param_names=("loc", "scale"))

def model():
numpyro.sample("c", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})
numpyro.sample("x", TruncatedNormal(2, 3))

mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=10000)
mcmc.run(jax.random.PRNGKey(0))
x = mcmc.get_samples()["x"]
plt.hist(x);
``````

You can also define an improper distribution if you want. MCMC will give similar results

``````class ImproperTruncatedNormal(dist.Normal):
support = dist.constraints.positive

funsor.distribution.make_dist(ImproperTruncatedNormal, param_names=("loc", "scale"))

def model():
numpyro.sample("c", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})
numpyro.sample("x", ImproperTruncatedNormal(2, 3))
``````

Thanks @martinjankowiak and @fehiepsi . I tried using the `ImproperTruncatedNormal` distribution class as shared by @fehiepsi , but I’m still getting the same error as in the first message of this thread: `ValueError: Invalid shape: expected (), actual (1,)`

Am I right in concluding that there is no way around this except waiting for issue #597 to be implemented?

Ok, with a bit of luck, I’ve figured out the issue I think.

The above error happens when the length of the plate variable (`L` in my case) is 1; when `L > 1` the error doesn’t happen.

It’s strange, because with `L = 1` the code works when I’m not using `TransformedDistribution`, `TruncatedNormal` or `ImproperTruncatedNormal`; it also works when I do the transformation myself, e.g.:

``````theta_raw = dist.Normal(...)
theta = jnp.exp(theta_raw)
``````

Not sure what’s up with that.

In either case, using `L > 1` seems to make things work, and the sampled values are reasonable (at least on the data I’m currently using to test). I will test more and report if something is wrong again.

It looks like you are facing this issue: “initialize_model() fails with batch size of 1 for model with discrete variables” (pyro-ppl/numpyro issue #1448 on GitHub).

1 Like