Question about a strange error when using block in numpyro

In my model, I want to block a sample site to prevent it from being inferred by SVI. The following code shows how I do it.
However, the code raises an AssertionError with no description. Am I using block incorrectly?

from jax.random import PRNGKey
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI,Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

def model():
    with numpyro.handlers.block():
        a = numpyro.sample('a', dist.Normal(0,1))
    b = numpyro.sample('b', dist.Dirichlet(concentration=jnp.array([2., 3, 4, 5, 6])))

optim = numpyro.optim.Adam(step_size=1e-3)
elbo = Trace_ELBO()
guide = AutoDelta(model)
svi = SVI(model, guide, optim, elbo)
svi_result = svi.run(PRNGKey(0), 1000)

The traceback is:

AssertionError
exception: no description
  File "/home/yw/Downloads/numpyro/numpyro/distributions/continuous.py", line 1706, in sample
    assert is_prng_key(key)
  File "/home/yw/Downloads/numpyro/numpyro/distributions/distribution.py", line 262, in sample_with_intermediates
    return self.sample(key, sample_shape=sample_shape), []
  File "/home/yw/Downloads/numpyro/numpyro/distributions/distribution.py", line 304, in __call__
    return self.sample_with_intermediates(key, *args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 24, in default_process_message
    msg["value"], msg["intermediates"] = msg["fn"](
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 53, in apply_stack
    default_process_message(msg)
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 222, in sample
    msg = apply_stack(initial_msg)
  File "/home/yw/Documents/tmm/test.py", line 13, in generative_model
    a = numpyro.sample('a', dist.Normal(0,1))
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/handlers.py", line 171, in get_trace
    self(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/infer/util.py", line 404, in _get_model_transforms
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/infer/util.py", line 606, in initialize_model
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/infer/autoguide.py", line 156, in _setup_prototype
    ) = initialize_model(
  File "/home/yw/Downloads/numpyro/numpyro/infer/autoguide.py", line 406, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/infer/autoguide.py", line 432, in __call__
    self._setup_prototype(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/handlers.py", line 171, in get_trace
    self(*args, **kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/infer/svi.py", line 180, in init
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
  File "/home/yw/Downloads/numpyro/numpyro/infer/svi.py", line 342, in run
    svi_state = self.init(rng_key, *args, **kwargs)
  File "/home/yw/Documents/tmm/test.py", line 20, in <module>
    svi_result = svi.run(PRNGKey(0), 1000)

Hi! @fehiepsi @ordabayev
I found the cause of this problem. It seems that the block handler hides the sample sites, so SVI can't pass its PRNG key to the hidden sample sites. Is this behavior by design, or could it be improved?
In my opinion, block shouldn't cut sample sites off from outside PRNG keys, because PRNG keys are not a primitive concept in probability theory.

I think you can use expose_types=["prng_key"] in the block handler. Alternatively, you can do

from numpyro import handlers

key = numpyro.prng_key()  # draw a key from the enclosing seed handler
with handlers.block(), handlers.seed(rng_seed=key):
    sub_program()

The latter approach makes it clear that the main program has deterministic keys regardless of the number of sample statements in the blocked program.

Thanks for the reply, I really appreciate it!

By the way, I have a question about the RNG key in SVI.
If all the SVI steps use the same key, say PRNGKey(0), then I would expect every sample site in the model and guide to draw the same value each time. That is obviously not what happens, so how does SVI deal with this fixed random number problem?

Each SVI step uses a new key, split from the key of the previous step, like

next_key, model_key, guide_key = random.split(current_key, 3)
# use model key, guide key to calculate the objective
# then use next_key in the next step
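The splitting scheme above can be tried in plain JAX. This is only an illustration of the idea, not NumPyro's actual SVI code; `run_steps` is a hypothetical helper, and a single normal draw stands in for the objective computation.

```python
from jax import random


def run_steps(seed, num_steps):
    """Collect one draw per step, splitting the carried key each time."""
    current_key = random.PRNGKey(seed)
    draws = []
    for _ in range(num_steps):
        next_key, model_key, guide_key = random.split(current_key, 3)
        # Stand-in for the objective: one draw made with this step's guide key.
        draws.append(float(random.normal(guide_key)))
        # Carry the fresh key forward so the next step sees different randomness.
        current_key = next_key
    return draws


first = run_steps(0, 3)
second = run_steps(0, 3)
print(first)            # the draws differ from step to step
print(first == second)  # the same initial key reproduces the same sequence
```

So a fixed initial key makes the whole run reproducible, while each step still draws different values.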

Thanks for the helpful reply!