I wrote a model in which a relaxed categorical distribution defines a latent variable. Here is what my model looks like at a high level:
def model(data):
    """Pyro model with a relaxed one-hot categorical latent site "assignment".

    Registers ``network`` with Pyro via ``pyro.module`` under the name
    "weight_nn". It feeds ``data['input']`` through ``network`` and then
    ``make_simplex``, and uses the result as the logits of a
    ``RelaxedOneHotCategorical``. It samples the latent site "assignment"
    from that distribution inside the "observed" plate. The remainder of the
    model is elided in this excerpt.

    Args:
        data: indexable container whose ``'input'`` entry is passed to
            ``network``. Presumably its leading dimension indexes
            observations — confirm against the caller.
    """
    pyro.module("weight_nn", network)
    log_unnormalized_weights = network(data['input'])
    # make_simplex is an opaque helper here; presumably it normalizes the
    # network output to log-probabilities — TODO confirm.
    log_simplex_weights = make_simplex(log_unnormalized_weights)
    # pyro.plate's second argument is the plate *size* (an int), not the data
    # container itself. Assumes one row of data['input'] per observation, and
    # that this count matches the rightmost batch dim of the logits — confirm.
    with pyro.plate("observed", len(data['input'])):
        # NOTE(review): under AutoNormal the value at this site is proposed by
        # the guide and then scored here. Any value that misses this
        # distribution's simplex support check raises
        # "The value argument must be within the support". Floating-point
        # round-off in the guide's transform onto the simplex is one possible
        # cause (unconfirmed). Workarounds to evaluate: run in float64, or pass
        # validate_args=False to this distribution.
        # `temperature` comes from the enclosing scope.
        assignment = pyro.sample(
            "assignment",
            dist.RelaxedOneHotCategorical(temperature, logits=log_simplex_weights),
        )
    ...
I wanted to start with an AutoGuide. Things worked fine with `AutoDelta`, but when I switched to `AutoNormal`, I saw the following error:
...
Traceback (most recent call last):
File "main.py", line 369, in <module>
main()
File "main.py", line 289, in main
loss, seed = min(
File "main.py", line 291, in <genexpr>
initialize(
File "/home/experiments/utils/initialization.py", line 80, in initialize
return svi.loss(model, guide, data)
File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/infer/tracegraph_elbo.py", line 276, in loss
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/infer/elbo.py", line 178, in _get_traces
self._guess_max_plate_nesting(model, guide, args, kwargs)
File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/infer/elbo.py", line 116, in _guess_max_plate_nesting
model_trace.compute_log_prob()
File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/poutine/trace_struct.py", line 236, in compute_log_prob
raise ValueError(
File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/poutine/trace_struct.py", line 230, in compute_log_prob
log_p = site["fn"].log_prob(
File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/torch/distributions/transformed_distribution.py", line 138, in log_prob
self._validate_sample(value)
File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/torch/distributions/distribution.py", line 277, in _validate_sample
raise ValueError('The value argument must be within the support')
ValueError: Error while computing log_prob at site 'assignment':
The value argument must be within the support
How can I make sure that the value argument is within the support when I am using an AutoGuide?