Argument must be within the support

Hello, everyone
I’m kinda new with the Pyro package and have been working on a variational autoencoder with the DeepMind Dsprites dataset. I’m trying to do a SVI on my model and guide but I keep getting bugs, and now I got “Error while computing log_prob at site ‘N_img’: The value argument must be within the support” and don’t know how to fix it, the variable N_img have a Uniform distribution [0,1] in both model and guide.

Here is the model and guide:

class SCM():
  def __init__(self, vae):
    self.vae = vae
    self.latents_names = ['color', 'shape', 'scale', 'orien', 'posX', 'posY']
    self.latents_sizes = {'color' : 1, 'shape': 3, 'scale':6, 'orien': 40, 'posX': 32, 'posY':32}
    temp = tensor([0.1])
    dist = RelaxedOneHotCategorical
    self.init_noise = {'color': dist(temp,probs = torch.tensor([1.])),
                       'shape': dist(temp,probs = torch.tensor([0.4, 0.4, 0.2])),
                       'scale': dist(temp,probs = torch.tensor([1/6]).repeat(6)),
                       'orien': dist(temp,probs = torch.tensor([1/40]).repeat(40)),
                       'posX': dist(temp,probs = torch.tensor([1/32]).repeat(32)),
                       'posY': dist(temp,probs = torch.tensor([1/32]).repeat(32)),
                       'latent': Normal(torch.zeros(200), torch.ones(200)),
                       'img': Uniform(torch.zeros(4096), torch.ones(4096))}
    #functions for the model
    def f_gumbel(N):
      return N
    def f_posX(N, scale):
      if 31 <= int(scale.max(0)[1] + N.max(0)[1]):
        return torch.nn.functional.one_hot(torch.tensor([31]), 32).to(torch.float32).reshape([32])
        return torch.nn.functional.one_hot(scale.max(0)[1] + N.max(0)[1], 32).to(torch.float32).reshape([32])

    def f_latent(N_latent, color, shape, scale, orien, posX, posY):
      ind = ind_from_att(color.max(0)[1], shape.max(0)[1], scale.max(0)[1], orien.max(0)[1], posX.max(0)[1], posY.max(0)[1])
      label = torch.round([color, shape, scale, orien, posX, posY], -1))
      mu, sigma = vae.encoder.forward(torch.tensor(imgs[ind]).reshape(4096).to(torch.float32), label)
      return N_latent * sigma + mu

    def f_image(N_img, latent, color, shape, scale, orien, posX, posY):
      label = torch.round([color, shape, scale, orien, posX, posY], -1))
      img_decode = vae.decoder.forward(latent, label)
      return (N_img > img_decode).to(torch.float)

    def model(noise = self.init_noise):

      #Noise variables
      N_color = pyro.sample('N_color', noise['color'])
      N_shape = pyro.sample('N_shape', noise['shape'])
      N_scale = pyro.sample('N_scale', noise['scale'])
      N_orien = pyro.sample('N_orien', noise['orien'])
      N_posX = pyro.sample('N_posX', noise['posX'])
      N_posY = pyro.sample('N_posY', noise['posY'])
      N_latent = pyro.sample('N_latent', noise['latent'].to_event(1))
      N_img = pyro.sample('N_img', noise['img'].to_event(1))
      #print(N_color, N_shape, N_scale, N_orien, N_posX, N_posY)
      color = pyro.sample('color', Normal(f_gumbel(N_color), torch.tensor([0.01])).to_event(1))
      shape = pyro.sample('shape', Normal(f_gumbel(N_shape), torch.tensor([0.01])).to_event(1))
      scale = pyro.sample('scale', Normal(f_gumbel(N_scale), torch.tensor([0.01])).to_event(1))
      orien = pyro.sample('orien', Normal(f_gumbel(N_orien), torch.tensor([0.01])).to_event(1))
      posX = pyro.sample('posX', Normal(f_posX(N_posX, scale), torch.tensor([0.01])).to_event(1))
      posY = pyro.sample('posY', Normal(f_gumbel(N_posY), torch.tensor([0.01])).to_event(1))

      latent = pyro.sample('latent', Normal(f_latent(N_latent, color, shape, scale, orien, posX, posY), torch.tensor([0.01])).to_event(1))
      img = pyro.sample('img', Normal(f_image(N_img, latent, color, shape, scale, orien, posX, posY), torch.tensor([0.01])).to_event(1))

      return torch.round([color, shape, scale, orien, posX, posY], -1)), latent, img
    def guide(noise = self.init_noise):
      cat_val =  {'color': 1, 'shape': 3, 'scale': 6, 'orien': 40, 'posX': 32, 'posY': 32}
      temp = {key : pyro.param('temp_'+key, tensor([0.4]), constraint = constraints.positive) for key in cat_val}
      probs = {key : pyro.param('prob_'+key, torch.ones(value)*0.5, constraint = constraints.interval(tensor(0.), tensor(1.))) for key, value in cat_val.items()}

      mu_latent = pyro.param('mu_latent', torch.ones(200)*0.5)
      sigma_latent = pyro.param('sigma_latent', torch.ones(200), constraint = constraints.positive)

      low_img = pyro.param('low_img', torch.zeros(4096), constraint = constraints.positive)
      high_img = pyro.param('high_img', torch.ones(4096), constraint = constraints.positive)
      for key in cat_val.keys():
        pyro.sample('N_'+key, RelaxedOneHotCategorical(temp[key], probs[key]))

      pyro.sample('N_latent', Normal(mu_latent, sigma_latent).to_event(1))
      print(pyro.sample('N_img', Uniform(low_img, high_img).to_event(1)))


    def viz_model():
      label, _, img = model()
      label = label_from_dummy(label)
      plt.imshow(img.detach().numpy().reshape(64, 64), cmap = 'Greys')
      text = 'Color:0    Shape:'+str(int(label[1])) + '   Scale:' + str(int(label[2])) + '   Orien.:' + str(int(label[3])) + '   Pos.X:' + str(int(label[4])) + '   Pos.Y:' + str(int(label[5]))

    self.model = model = guide
    self.viz_model = viz_model

scm = SCM(vae)

Hi @GiovaniValdrighi, at first sight I see low_img should be positive but is initialized to zero; try 1e-3 or something. Also, I’ve actually never tried to use a Uniform distribution in a guide, YMMV. More generally, I’d open up a debugger and try printing out the param value that is erroring, probably print self.ub in the debugger.