In my model, i want to block a sample site to prevent it from being infered by the svi, the following code is how i do it.
But the code gives me assertion error with no description, please help me, am i using block wrong?
from jax.random import PRNGKey
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI,Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
def model():
with numpyro.handlers.block():
a = numpyro.sample('a', dist.Normal(0,1))
b = numpyro.sample('b', dist.Dirichlet(concentration=jnp.array([2., 3, 4, 5, 6])))
optim = numpyro.optim.Adam(step_size=1e-3)
elbo = Trace_ELBO()
guide = AutoDelta(model)
svi = SVI(model, guide, optim, elbo)
svi_result = svi.run(PRNGKey(0), 1000)
The error code is
AssertionError
exception: no description
File "/home/yw/Downloads/numpyro/numpyro/distributions/continuous.py", line 1706, in sample
assert is_prng_key(key)
File "/home/yw/Downloads/numpyro/numpyro/distributions/distribution.py", line 262, in sample_with_intermediates
return self.sample(key, sample_shape=sample_shape), []
File "/home/yw/Downloads/numpyro/numpyro/distributions/distribution.py", line 304, in __call__
return self.sample_with_intermediates(key, *args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 24, in default_process_message
msg["value"], msg["intermediates"] = msg["fn"](
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 222, in sample
msg = apply_stack(initial_msg)
File "/home/yw/Documents/tmm/test.py", line 13, in generative_model
a = numpyro.sample('a', dist.Normal(0,1))
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/infer/util.py", line 404, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/home/yw/Downloads/numpyro/numpyro/infer/util.py", line 606, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/home/yw/Downloads/numpyro/numpyro/infer/autoguide.py", line 156, in _setup_prototype
) = initialize_model(
File "/home/yw/Downloads/numpyro/numpyro/infer/autoguide.py", line 406, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/infer/autoguide.py", line 432, in __call__
self._setup_prototype(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/home/yw/Downloads/numpyro/numpyro/infer/svi.py", line 180, in init
guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
File "/home/yw/Downloads/numpyro/numpyro/infer/svi.py", line 342, in run
svi_state = self.init(rng_key, *args, **kwargs)
File "/home/yw/Documents/tmm/test.py", line 20, in <module>
svi_result = svi.run(PRNGKey(0), 1000)