Regarding the error from minimize(), I’ll share the error message below. If that isn’t enough, I’ll put together another minimal example — I just need to add a model function to my previous example in this topic.
Here’s the “short” traceback, which excludes JAX’s internal frames (the full version was too long to post):
FilteredStackTrace Traceback (most recent call last)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/1815523977.py in <module>
1 #Predictive(model, guide, num_samples=2000)
----> 2 guide.quantiles(params, 0.5)
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in quantiles(self, params, quantiles)
1110 def quantiles(self, params, quantiles):
-> 1111 transform = self.get_transform(params)
1112 quantiles = jnp.array(quantiles)[..., None]
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in get_transform(self, params)
1082 loc = params["{}_loc".format(self.prefix)]
-> 1083 precision = hessian(loss_fn)(loc)
1084 scale_tril = cholesky_of_inverse(precision)
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in loss_fn(z)
1079 params1["{}_loc".format(self.prefix)] = z
-> 1080 return self._loss_fn(params1)
1081
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in loss_fn(params)
1062 return Trace_ELBO().loss(
-> 1063 random.PRNGKey(0), params, self.model, self, *args, **kwargs
1064 )
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
53 return self.loss_with_mutable_state(
---> 54 rng_key, param_map, model, guide, *args, **kwargs
55 )["loss"]
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
149 if self.num_particles == 1:
--> 150 elbo, mutable_state = single_particle_elbo(rng_key)
151 return {"loss": -elbo, "mutable_state": mutable_state}
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in single_particle_elbo(rng_key)
124 model_log_density, model_trace = log_density(
--> 125 seeded_model, args, kwargs, params
126 )
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
52 model = substitute(model, data=params)
---> 53 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
54 log_joint = jnp.zeros(())
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
164 """
--> 165 self(*args, **kwargs)
166 return self.trace
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/3862560390.py in model(data, prior_distribution, dir_mat, s, sample_vars, data_bl, aux_predictors_mat, intercept_aux, predictive)
29 intercept_aux, rtol=1e-6, atol=1e-5,
---> 30 mxstep=1000))(data_bl_stocks, data_bl_constants)
31 y_pred = jnp.concatenate(z_array)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/3862560390.py in <lambda>(y0, constants)
29 intercept_aux, rtol=1e-6, atol=1e-5,
---> 30 mxstep=1000))(data_bl_stocks, data_bl_constants)
31 y_pred = jnp.concatenate(z_array)
~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
172 converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 173 return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
174
~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args)
178 func = ravel_first_arg(func, unravel)
--> 179 out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
180 return jax.vmap(unravel)(out)
~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args)
217 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
--> 218 ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
219 return ys, (ys, ts, args)
FilteredStackTrace: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.