I have a model (referred to below as `joint_model`) that I'm able to run with NUTS. Now I want to try variational inference to get a (hopefully) faster approximate answer, but I'm running into this issue:
```
>>> guide = AutoNormal(joint_model)
... svi = SVI(joint_model, guide, optimizer, loss=Trace_ELBO())
... init_state = svi.init(random.key(0))
...
Traceback (most recent call last):
  File "<python-input-20>", line 3, in <module>
    init_state = svi.init(random.key(0))
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/infer/svi.py", line 184, in init
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 191, in get_trace
    self(*args, **kwargs)
    ~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 846, in __call__
    return cloned_seeded_fn.__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 847, in __call__
    return super().__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py", line 411, in __call__
    self._setup_prototype(*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py", line 401, in _setup_prototype
    full_size = self._prototype_frame_full_sizes[frame.name]
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'std_normals_plate'
```
Now, `joint_model` is a composition of several modules. The one that has `std_normals_plate` is this one:
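```python
import numpyro
import numpyro.distributions as dist

def CAR_layer_units(base_shape, adj_matrix, rho_prior=dist.Beta(1000.0, 0.9)):
    # CAR dependence parameter; Beta(1000, 0.9) puts almost all its mass near 1
    rho = numpyro.sample('rho', rho_prior)
    # one iid standard normal per row of adj_matrix, batched along dim -2
    with numpyro.plate("std_normals_plate", adj_matrix.shape[0], dim=-2):
        zs = numpyro.sample('units', dist.Normal())
    # transform iid normals -> correlated normals -> correlated U[0,1] variables
    # (transform_iid_normal_layers is my own helper, not shown here)
    us = transform_iid_normal_layers(base_shape, adj_matrix, rho, zs)
    return us
```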
The weird thing is that running SVI on this module on its own succeeds:
```
>>> guide = AutoNormal(CAR_layer_units)
... svi = SVI(CAR_layer_units, guide, optimizer, loss=Trace_ELBO())
... init_state = svi.init(random.key(0), base_shape, adj_matrix)
...
>>>
```
Sadly, I can't share the whole model here. I'm hoping someone can give me a hint about what might be going wrong just from the traceback.
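In case it narrows things down: the last frame of the traceback is AutoNormal looking up `frame.name` in `_prototype_frame_full_sizes`. My reading of `numpyro/infer/autoguide.py` (which may be off) is that this dict is filled from the trace's `"plate"` sites, so the error would mean some sample site lists `std_normals_plate` in its `cond_indep_stack` while no plate site with that name was recorded. Below is a minimal sketch of that comparison. `find_missing_plate_frames` is a hypothetical helper of mine, not a NumPyro API, and it assumes the trace is a plain mapping from site name to site dict.

```python
from collections.abc import Mapping
from typing import Any


def find_missing_plate_frames(trace: Mapping[str, dict[str, Any]]) -> dict[str, list[str]]:
    """Return {plate frame name: [sample sites using it]} for every frame that
    a sample site references but that has no "plate" site in the trace.

    Assumption: `trace` has the shape returned by
    numpyro.handlers.trace(...).get_trace(...), i.e. site name -> site dict.
    """
    # Names of plates that were actually recorded as "plate" sites.
    recorded_plates = {
        name for name, site in trace.items() if site.get("type") == "plate"
    }
    missing: dict[str, list[str]] = {}
    for name, site in trace.items():
        if site.get("type") != "sample":
            continue
        # Each entry of cond_indep_stack is a frame object with a .name attribute.
        for frame in site.get("cond_indep_stack", []):
            if frame.name not in recorded_plates:
                missing.setdefault(frame.name, []).append(name)
    return missing
```

Since `svi.init` above is called without model arguments, this could be run as `find_missing_plate_frames(handlers.trace(handlers.seed(joint_model, 0)).get_trace())` after `from numpyro import handlers`. A non-empty result would suggest that something between `joint_model` and `CAR_layer_units` keeps the plate site out of the trace; that is a guess, not something I've confirmed.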
Cheers