"The value argument must be within the support" error when using an AutoGuide

I wrote a model in which a relaxed categorical distribution defines a latent variable. Here is what my model looks like at a high level:

def model(data):
    pyro.module("weight_nn", network)
    log_unnormalized_weights = network(data['input'])
    log_simplex_weights = make_simplex(log_unnormalized_weights)
    with pyro.plate("observed", len(data['input'])):
        assignment = pyro.sample("assignment", dist.RelaxedOneHotCategorical(temperature, logits=log_simplex_weights))
        ...

I wanted to start off the solution with an AutoGuide. Things worked fine for AutoDelta but when I switched to AutoNormal, I saw the following error:

...
Traceback (most recent call last):                                                                                                                                                                                    
  File "main.py", line 369, in <module>                                                                                                                                                   
    main()                                                                                                                                                                                                            
  File "main.py", line 289, in main                                                                                                                                                       
    loss, seed = min(                                                                                                                                                                                                 
  File "main.py", line 291, in <genexpr>                                                                                                                                                  
    initialize(                                                                                                                                                                                  
  File "/home/experiments/utils/initialization.py", line 80, in initialize                                                                                                        
    return svi.loss(model, guide, data)                                                                                                                                                                               
  File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/infer/tracegraph_elbo.py", line 276, in loss                                                    
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):                                                                                                                                     
  File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/infer/elbo.py", line 178, in _get_traces                                                        
    self._guess_max_plate_nesting(model, guide, args, kwargs)                                                                                                                                                         
  File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/infer/elbo.py", line 116, in _guess_max_plate_nesting                                           
    model_trace.compute_log_prob()                                                                                                                                                                                    
  File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/poutine/trace_struct.py", line 236, in compute_log_prob                                         
    raise ValueError(                                                                                                                                                                                                 
  File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/pyro_ppl-1.8.1+045f9e48-py3.8.egg/pyro/poutine/trace_struct.py", line 230, in compute_log_prob                                         
    log_p = site["fn"].log_prob(                                                                                                                                                                                      
  File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/torch/distributions/transformed_distribution.py", line 138, in log_prob                                                                
    self._validate_sample(value)                                                                                                                                                                                      
  File "/home/miniconda3/envs/py3-mink/lib/python3.8/site-packages/torch/distributions/distribution.py", line 277, in _validate_sample                                                                    
    raise ValueError('The value argument must be within the support')                                                                                                                                                 
ValueError: Error while computing log_prob at site 'assignment':                                                                                                                                                      
The value argument must be within the support 

How can I make sure that the value argument stays within the support when I am using an AutoGuide?

you might try different values for the init_loc_fn argument, in particular init_loc_fn=init_to_median might work (the default for AutoDelta)

Thanks @martinjankowiak for the tips! I tried init_to_sample, init_to_median, init_to_feasible, init_to_generated, init_to_mean, init_to_uniform, and init_to_value. However, all of these initialization methods produced the same error as above.

More specifically, I used them in the following statement as init_method.

guide = AutoNormal(model, init_loc_fn=init_method)

Do you have any other tips?

nope but it would probably help if you provided a complete runnable script that triggers the error

Finally solved this!

TL;DR: using float64 instead of float32 solved the issue. This is probably the safer choice whenever an auto guide is combined for SVI with a model whose latent variables are sampled from distributions with simplex support.
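A minimal sketch of the global variant, assuming plain PyTorch. The tensor names and shapes are made up for illustration, and the previous default is restored at the end only to keep the sketch side-effect free:

```python
import torch

# Remember the current default so the change stays local to this sketch.
previous_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    # Newly created floating-point tensors are now float64 ...
    latent = torch.randn(4, 3)
    # ... but tensors that already exist keep their dtype and need an explicit cast.
    old_data = torch.ones(4, 3, dtype=torch.float32)  # hypothetical pre-existing data
    data = old_data.double()
finally:
    torch.set_default_dtype(previous_dtype)

print(latent.dtype, data.dtype)
```

In a real script the default would be set once at the top, and any network or data created earlier in float32 would presumably need a cast as well (for example with .double()).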

Background

I traced where the problem arises in AutoNormal. Since assignment is a latent variable sampled from RelaxedOneHotCategorical, the value in this line should lie on the simplex, which is the support of this distribution. However, I observed that some rows of the matrix do not pass the simplex check (I checked using this statement). The cause is that the guide uses biject_to as the transformation, and for the simplex constraint (a stick-breaking transform) rounding error in float32 can push the output slightly off the simplex. When I increased the precision, things worked fine. For example, I managed to work around the issue by replacing the relevant line with

value_double = transform(unconstrained_latent.double())
value = value_double.float()

since I did not want to set torch.set_default_dtype(torch.float64).
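Here is a minimal standalone sketch of the same trick outside of AutoNormal. The tensor shape and the random input are made up for illustration (unconstrained_latent is only a stand-in for the guide's internal value), and whether any float32 row actually fails the check depends on the inputs:

```python
import torch
from torch.distributions import biject_to, constraints

torch.manual_seed(0)

# Hypothetical stand-in for the guide's unconstrained latent:
# 1000 rows, 49 unconstrained values each (mapped to 50 simplex entries).
unconstrained_latent = torch.randn(1000, 49)

# For the simplex constraint, biject_to returns a stick-breaking transform.
transform = biject_to(constraints.simplex)

# Direct float32 evaluation versus the float64 round trip.
value_float = transform(unconstrained_latent)
value_double = transform(unconstrained_latent.double())
value = value_double.float()


def max_sum_error(x):
    """Largest deviation of a row sum from 1."""
    return (x.sum(-1) - 1).abs().max().item()


print("float32 max |sum - 1|:", max_sum_error(value_float))
print("float64 max |sum - 1|:", max_sum_error(value_double))
print("float64 rows on simplex:", bool(constraints.simplex.check(value_double).all()))
```

The simplex check compares each row sum against 1 with a small fixed tolerance, so the float64 result has far more headroom than the float32 one.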

It would have been better if transform were transform_to for the simplex instead of biject_to, since transform_to is the one suggested for SVI. However, I am also aware that SoftmaxTransform does not implement log_abs_det_jacobian yet, so the precision trick seems to be the easiest solution for now.
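For reference, a small sketch of the difference between the two registries for the simplex constraint, assuming the current torch.distributions behavior (the variable names are only illustrative):

```python
import torch
from torch.distributions import biject_to, transform_to, constraints
from torch.distributions.transforms import SoftmaxTransform, StickBreakingTransform

bijection = biject_to(constraints.simplex)      # stick-breaking, K - 1 -> K
surjection = transform_to(constraints.simplex)  # softmax, K -> K

print(type(bijection).__name__, "bijective:", bijection.bijective)
print(type(surjection).__name__, "bijective:", surjection.bijective)

# Both map onto a 3-dimensional simplex, from different input sizes.
y_stick = bijection(torch.zeros(2))
y_soft = surjection(torch.zeros(3))

# The softmax transform is not a bijection, so no Jacobian term is available.
try:
    surjection.log_abs_det_jacobian(torch.zeros(3), y_soft)
    softmax_has_jacobian = True
except NotImplementedError:
    softmax_has_jacobian = False

print("softmax implements log_abs_det_jacobian:", softmax_has_jacobian)
```

An auto guide needs the Jacobian term to compute the density of the transformed sample, which is presumably why it relies on biject_to here.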