Bug(s) in RelaxedOneHotCategorical?

I am fitting a very simple VAE model with two latent variables: one that specifies Normal latent exposures (the original spec) and a second that specifies latent clusters (a recent addition to the model) with 40 clusters, a temperature of 0.35, and a dtype of float32. The prior was specified with probs and passed to the RelaxedOneHotCategorical constructor as probs=XXX in the guide / model, and the parameter being fit was passed to sample(…, obs=XXX) in the model.
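
For context, the failing site boils down to something like the following stripped-down sketch. The names (K, TEMPERATURE, cluster_probs, etc.) are placeholders rather than the real model, which also carries the Normal latent exposures:

```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

K = 40                                   # number of latent clusters
TEMPERATURE = torch.tensor(0.35)
PRIOR_PROBS = torch.full((K,), 1.0 / K)  # the flat 0.0250 prior shown below

def model():
    # The fitted simplex parameter is scored against the relaxed prior
    # by passing it as the observation.
    cluster_probs = pyro.param("cluster_probs", torch.full((K,), 1.0 / K),
                               constraint=constraints.simplex)
    pyro.sample("cluster",
                dist.RelaxedOneHotCategorical(temperature=TEMPERATURE,
                                              probs=PRIOR_PROBS),
                obs=cluster_probs)
```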

The smoke test came together quickly (5 iterations), and then I began fitting the model with roughly 75 iterations.

Once fitting started, some fits worked fine, but occasionally – increasingly often as the number of iterations grew, and apparently conditioned on the temperature value – I would get errors that ended the fit. The probs passed to the distribution had these values:

tensor([0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250,
0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250,
0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250,
0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250, 0.0250,
0.0250, 0.0250, 0.0250, 0.0250])

The fitted parameter passed in sample(…, obs=XXX) had these values:

tensor([9.9970e-06, 9.9970e-06, 9.9970e-06, 9.9970e-06, 9.9970e-06, 9.9970e-06,
9.9970e-06, 9.9970e-06, 9.9970e-06, 9.9970e-06, 9.2487e-01, 9.9970e-06,
9.9970e-06, 2.0528e-04, 9.9970e-06, 9.9970e-06, 9.9970e-06, 9.9970e-06,
9.9970e-06, 9.9970e-06, 9.9970e-06, 9.9970e-06, 1.0555e-04, 9.9970e-06,
9.9970e-06, 9.9970e-06, 2.1660e-05, 6.1786e-02, 9.9970e-06, 9.9970e-06,
9.9970e-06, 9.9970e-06, 6.1573e-04, 9.9970e-06, 9.9970e-06, 9.9970e-06,
9.9970e-06, 9.9970e-06, 1.2062e-02, 9.9970e-06])

The error in this case was:

ValueError: Expected value argument (Tensor of shape (40,)) to be within the support (Simplex()) of the distribution RelaxedOneHotCategorical(), but found invalid values:

tensor([9.4518e-01, 1.1688e-11, 5.2286e-12, 5.6717e-11, 2.4537e-13, 4.0543e-14,
6.5703e-12, 5.0752e-12, 5.4818e-02, 4.2382e-11, 2.1359e-12, 1.0125e-13,
1.0480e-12, 2.0709e-12, 3.0333e-10, 1.3162e-13, 4.4372e-12, 3.1155e-12,
4.3695e-14, 1.3960e-12, 1.4140e-11, 1.0914e-11, 2.0808e-12, 4.7455e-12,
2.8321e-12, 3.8122e-12, 2.4209e-11, 4.7964e-13, 4.1868e-10, 5.1994e-12,
4.0029e-12, 2.9496e-12, 5.8078e-13, 8.1536e-11, 2.0167e-07, 1.7596e-11,
5.0181e-13, 5.7172e-09, 5.4377e-14, 1.8678e-08])

Some thoughts / observations:

  1. Raising the temperature reduces the frequency of what looks like a non-deterministic explosion.

  2. How far the probabilities produced by RelaxedOneHotCategorical shift away from the values in the parameter is obviously controlled by the temperature value.

  3. Suppressing the error allows the fit to continue (the error does not recur on the next evaluation of svi.step), but obviously this is not a solution I would entertain.

  4. The probabilities in the prior and in the parameter passed to XXX.sample(…, obs=param_k) both sum to exactly one, but the probabilities that come out of RelaxedOneHotCategorical.sample in the failing example sum to 0.9999960000000008 (see the sketch after this list).

  5. Switching the dtype to float64 also reduces the frequency, but this also isn’t a solution.
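
To put observations 4 and 5 on a standalone footing: as far as I can tell, the Simplex constraint in torch.distributions checks non-negativity plus |sum − 1| < 1e-6, and float32 draws can drift past that tolerance. A small sketch, independent of my model (the seed, draw count, and flat probs are arbitrary):

```python
import torch
from torch.distributions import RelaxedOneHotCategorical, constraints

torch.manual_seed(0)                              # arbitrary seed
d = RelaxedOneHotCategorical(temperature=torch.tensor(0.35),
                             probs=torch.full((40,), 1.0 / 40))
samples = d.sample((10_000,))                     # float32 draws
sums = samples.sum(-1)
valid = constraints.simplex.check(samples)        # non-negative and |sum - 1| < 1e-6
print("sum range:", sums.min().item(), sums.max().item())
print("fraction failing the simplex check:", (~valid).float().mean().item())
```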

Questions / Requests:

  1. What would be an example of an invalid value? What are the boundaries on this?

  2. What does “ValueError: Expected value argument (Tensor of shape (40,)) to be within the support (Simplex()) of the distribution RelaxedOneHotCategorical(), but found invalid values” really mean? It looks like the values on the least likely clusters are so small that there isn’t enough room for a valid Simplex.

  3. In cases where the failing values come from a parameter, would it be possible to report the name of that parameter in the parameter store (e.g., for context / forensics)? The output lists all of the distributions being sampled, but in my case there is one set of latent factors (the Normal distribution values) plus one latent cluster map for each county in the United States (~4,000 counties), so the list is 8,000+ samples long; my interest is really in the one sample that caused the failure. In the meantime I have been doing forensics with the rough workaround sketched below.
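
This is only a sketch; step_with_forensics and n_clusters are hypothetical names of my own, and svi and its arguments come from my training loop:

```python
import pyro

def step_with_forensics(svi, *args, n_clusters=40):
    """Run svi.step; on failure, dump the cluster-sized params for inspection."""
    try:
        return svi.step(*args)
    except ValueError:
        for name, value in pyro.get_param_store().items():
            if value.dim() and value.shape[-1] == n_clusters:
                print(name, value.detach().sum(-1))
        raise
```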

Getting a simple example of the failure has been challenging:

  1. Smoke test always works
  2. Fitting just a handful of counties has never required enough evaluations / draws to present the problem
  3. The exception data isn’t great in providing enough color to identify or reproduce the problem
  4. Would adding upper / lower bounds on the probabilities – in RelaxedOneHotCategorical, not in the parameter bijector – resolve the issue (see the sketch below)? In that case the simplex would still have room to search, but again that solution – specifying bounds that ensure a valid Simplex – would depend on the choice of temperature.
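
For concreteness, the kind of bounding I have in mind is roughly the following; clamp_to_simplex is a hypothetical helper shown outside the distribution purely to illustrate the idea, not something I am proposing as a fix:

```python
import torch

def clamp_to_simplex(p, eps=1e-6):
    # Clip vanishingly small probabilities and renormalize so the vector
    # passes the Simplex check; a safe eps would itself depend on temperature.
    p = p.clamp_min(eps)
    return p / p.sum(-1, keepdim=True)
```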

Taken together, it looks like there is a bit of code that gives rise to a numerical precision error when building the Simplex. Also, the output of RelaxedOneHotCategorical.sample(…) should sum to one.
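
If I am reading the PyTorch implementation correctly, the draw is essentially a Gumbel-softmax computed in log space and then exponentiated, so the result is only normalized up to floating-point rounding. A paraphrased sketch of that path (not the actual source):

```python
import torch

def relaxed_sample(logits, temperature):
    # Paraphrase of the sampling path as I understand it: Gumbel noise,
    # tempering, normalization in log space, then exp back to probabilities.
    uniforms = torch.rand_like(logits).clamp(1e-7, 1.0 - 1e-7)
    gumbels = -(-(uniforms.log())).log()
    scores = (logits + gumbels) / temperature
    log_probs = scores - scores.logsumexp(dim=-1, keepdim=True)
    return log_probs.exp()   # sums to approximately, not exactly, one in float32

x = relaxed_sample(torch.full((40,), 1.0 / 40).log(), 0.35)
print(x.sum().item())
```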