Warning when using GumbelSoftmaxReparam

I aim to create a model with a RelaxedOneHotCategorical distribution and fit it with SVI using an autoguide. According to https://github.com/pyro-ppl/pyro/pull/2693, GumbelSoftmaxReparam is intended to be used with autoguides such as AutoNormal, and an example is given in https://github.com/pyro-ppl/pyro/blob/dev/tests/infer/reparam/test_softmax.py#L29. Inspired by this example, I created the following setup:

Toy example
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import GumbelSoftmaxReparam
from pyro.infer.autoguide import AutoNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
shape = (4,)
dim = 2

temperature = torch.tensor(0.1)
logits = torch.randn(shape + (dim,))
def model():
    with pyro.plate_stack("plates", shape):
        with pyro.plate("particles", 10000):
            pyro.sample("x", dist.RelaxedOneHotCategorical(temperature,
                                                          logits=logits))
guide = AutoNormal(model)
reparam_model = poutine.reparam(model, {"x": GumbelSoftmaxReparam()})

elbo = Trace_ELBO()
adam_params = {"lr": 0.001, "betas": (0.95, 0.999)}
optimizer = Adam(adam_params)

svi = SVI(
    reparam_model,
    guide,
    optimizer,
    loss=elbo,
)
for _ in range(100):
    loss = svi.step()

The example can be downloaded from here. Running it raises the following warnings:

/usr/local/lib/python3.7/dist-packages/pyro/util.py:291: UserWarning: Found non-auxiliary vars in guide but not model, consider marking these infer={'is_auxiliary': True}:
{'x'}
  guide_vars - aux_vars - model_vars
/usr/local/lib/python3.7/dist-packages/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_uniform'}
  warnings.warn(f"Found vars in model but not guide: {bad_sites}")

Could anyone tell me whether I set something up wrong or whether this is a bug?

Thanks in advance!

I think you’ll want to pass the reparameterized model to the guide:

- guide = AutoNormal(model)
+ guide = AutoNormal(reparam_model)

Oh, that makes sense 😅 Thanks!
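For background on what the reparameterizer does to `x`: the Gumbel-softmax trick draws parameter-free Uniform(0, 1) noise `u`, maps it to Gumbel noise `g = -log(-log u)`, and returns `softmax((logits + g) / temperature)`. That is why the reparameterized model exposes `x_uniform` as its only latent site. Below is a plain-PyTorch sketch of the transform; `gumbel_softmax_sample` is a hypothetical name, and this illustrates the technique rather than reproducing Pyro's source.

```python
import torch


def gumbel_softmax_sample(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    """Relaxed one-hot sample via the Gumbel-softmax trick (illustrative sketch)."""
    # Parameter-free noise: one Uniform(0, 1) draw per category.
    u = torch.rand_like(logits)
    # Map uniform noise to standard Gumbel noise g = -log(-log(u));
    # the clamps keep both logs finite at the edges of the interval.
    tiny = torch.finfo(logits.dtype).tiny
    g = -(-u.clamp(min=tiny).log()).clamp(min=tiny).log()
    # Tempered softmax over the last (category) dimension.
    return ((logits + g) / temperature).softmax(dim=-1)


torch.manual_seed(0)
logits = torch.randn(4, 2, requires_grad=True)
x = gumbel_softmax_sample(logits, torch.tensor(0.1))
print(x.sum(dim=-1))  # each row sums to 1 (up to rounding)

# The sample is a differentiable function of the logits.
x[:, 0].sum().backward()
print(logits.grad.shape)  # torch.Size([4, 2])
```

Because all randomness lives in `u`, gradients flow from the sample back to `logits`, which is what lets an autoguide such as AutoNormal work on the uniform site instead of the relaxed categorical one.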