SVI params don't converge

Hi, I’m starting with Pyro and Bayesian inference and I’ve been working in a SVI, I want to infer noise variables that have Categorical distribution and are conditioned, but when the algorithm runs, it doesn’t converge to approximate values of the params

Here is the code:

def t_mean(tens):
return torch.mean(torch.stack(tens), dim = 0)

def dummy(shape, scale, orien, posX, posY):
t = []
t.append(tensor([1]))
t.append(torch.nn.functional.one_hot(tensor(shape), 3))
t.append(torch.nn.functional.one_hot(tensor(scale), 6))
t.append(torch.nn.functional.one_hot(tensor(orien), 40))
t.append(torch.nn.functional.one_hot(tensor(posX), 32))
t.append(torch.nn.functional.one_hot(tensor(posY), 32))
return torch.cat(t, -1).to(torch.float32)

class SCM_test():
def init(self, vae):
self.vae = vae
self.latents_names = [‘shape’, ‘scale’, ‘orien’, ‘posX’, ‘posY’]
dist = Categorical
self.init_noise = {‘shape’: dist(tensor([.4, .4, .2])),
‘scale’: dist(tensor([1., 2., 2., 2., 2., 1.])),
‘orien’: dist(tensor(1/40).repeat(40)),
‘posY’: dist(tensor(1/40).repeat(40)),
‘posX’: dist(tensor(1/32).repeat(32)),
‘latent’: Normal(torch.zeros(200), torch.ones(200)),
‘img’: Uniform(torch.zeros(4096), torch.ones(4096))}

def f_latent(N_latent, shape, scale, orien, posX, posY):
  shape = int(torch.round(shape))
  scale = int(torch.round(scale))
  orien = int(torch.round(orien))
  posX = int(torch.round(posX))
  posY = int(torch.round(posY))

  ind = ind_from_att(0, shape, scale, orien, posX, posY)
  label = dummy(shape, scale, orien, posX, posY)
  mu, sigma = vae.encoder.forward(torch.tensor(imgs[ind]).reshape(4096).to(torch.float32), label)
  return N_latent * sigma + mu

def f_image(N_img, latent, color, shape, scale, orien, posX, posY):
  shape = int(torch.round(shape))
  scale = int(torch.round(scale))
  orien = int(torch.round(orien))
  posX = int(torch.round(posX))
  posY = int(torch.round(posY))
  label = dummy(shape, scale, orien, posX, posY)
  img_decode = vae.decoder.forward(latent, label)
  return (N_img > img_decode).to(torch.float32)

def f_cat(N):
  return N.to(torch.float32)

def f_posX(N, scale):
  if 31 <= N+scale:
    return tensor(31., dtype = torch.float32)
  else:
    return tensor(N+scale, dtype = torch.float32)

def model(noise  = self.init_noise):
  #noise variables
  N_shape = pyro.sample('N_shape', noise['shape'])
  N_scale = pyro.sample('N_scale', noise['scale'])
  N_orien = pyro.sample('N_orien', noise['orien'])
  N_posX = pyro.sample('N_posX', noise['posX'])
  N_posY = pyro.sample('N_posY', noise['posY'])
  N_latent = pyro.sample('N_latent', noise['latent'].to_event(1))
  N_img = pyro.sample('N_img', noise['img'].to_event(1))
  #variables
  shape = pyro.sample('shape', Normal(f_cat(N_shape), tensor(0.01)))
  scale = pyro.sample('scale', Normal(f_cat(N_scale), tensor(0.01)))
  orien = pyro.sample('orien', Normal(f_cat(N_orien), tensor(0.01)))
  posX = pyro.sample('posX', Normal(f_posX(N_posX, scale), tensor(0.01)))
  posY = pyro.sample('posY', Normal(f_cat(N_posY), tensor(0.01)))

  latent = pyro.sample('latent', Normal(f_latent(N_latent, shape, scale, orien, posX, posY), tensor(0.01)).to_event(1))
  img = pyro.sample('img', Normal(f_image(N_img, latent, tensor([1.]), shape, scale, orien, posX, posY), tensor(0.01)).to_event(1))
  return shape, scale, orien, posX, posY, latent, img

def guide(noise = self.init_noise):
  #noise params
  prob_shape = pyro.param('prob_shape', tensor(1/3).repeat(3), constraint = constraints.positive)
  prob_scale = pyro.param('prob_scale', tensor(1/6).repeat(6), constraint = constraints.positive)
  prob_orien = pyro.param('prob_orien', tensor(1/40).repeat(40), constraint = constraints.positive)
  prob_posX = pyro.param('prob_posX', tensor(1/32).repeat(32), constraint = constraints.positive)
  prob_posY = pyro.param('prob_posY', tensor(1/32).repeat(32), constraint = constraints.positive)
  mu = {}
  sigma = {}
  mu['latent'] = pyro.param('mu_latent', torch.ones(200)*0.01)
  sigma['latent'] = pyro.param('sigma_latent', torch.ones(200), constraint = constraints.positive)
  mu['img'] = pyro.param('mu_img', torch.ones(4096)*.01, constraint = constraints.interval(0.001, 0.99))
  sigma['img'] = pyro.param('sigma_img', torch.ones(4096), constraint = constraints.interval(0.001, 0.99))
  #noise variables
  N_shape = pyro.sample('N_shape', Categorical(prob_shape))
  N_scale = pyro.sample('N_scale', Categorical(prob_scale))
  N_orien = pyro.sample('N_orien', Categorical(prob_orien))
  N_posX= pyro.sample('N_posX', Categorical(prob_posX))
  N_posY= pyro.sample('N_posY', Categorical(prob_posY))
  N_latent = pyro.sample('N_latent', Normal(mu['latent'], sigma['latent']).to_event(1))
  N_img = pyro.sample('N_img', Uniform(mu['img'], sigma['img']).to_event(1))
  #variables
  shape = pyro.sample('shape', Normal(f_cat(N_shape), tensor(0.01)), infer={'is_auxiliary': True})
  scale = pyro.sample('scale', Normal(f_cat(N_scale), tensor(0.01)), infer={'is_auxiliary': True})
  orien = pyro.sample('orien', Normal(f_cat(N_orien), tensor(0.01)), infer={'is_auxiliary': True})
  posX = pyro.sample('posX', Normal(f_cat(N_posX), tensor(0.01)), infer={'is_auxiliary': True})
  posY = pyro.sample('posY', Normal(f_cat(N_posY), tensor(0.01)), infer={'is_auxiliary': True})
  latent = pyro.sample('latent', Normal(f_latent(N_latent, shape, scale, orien, posX, posY), tensor(0.01)).to_event(1))
  img = pyro.sample('img', Normal(f_image(N_img, latent, tensor([1.]), shape, scale, orien, posX, posY), tensor(0.01)).to_event(1))
  return

self.model = model
self.guide = guide

def svi_calculate(scm, datacond, lr, betas, n_steps):
condModel = pyro.condition(scm.model, data = datacond)
pyro.clear_param_store()

#setting the SVI attributes
args = {‘lr’: lr, ‘betas’: betas}
optmizer = pyro.optim.Adam(args)
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(condModel, scm.guide, optmizer, elbo)
var1 = [‘shape’, ‘scale’, ‘orien’, ‘posX’,‘posY’]
var2 = [‘latent’, ‘img’]
#training loop
cat_samples = {key:[] for key in var1}
mu_samples = {key : [] for key in var2}
sigma_samples = {key: [] for key in var2}
losses = []
for i in range(n_steps):
losses.append(svi.step(scm.init_noise))
#saving samples
for key in var1:
cat_samples[key].append(pyro.param(‘prob_’+key))
#print(key, cat_samples[key][-1])
for key in var2:
mu_samples[key].append(pyro.param(‘mu_’+key))
sigma_samples[key].append(pyro.param(‘sigma_’+key))
#print(key, mu_samples[key][-1])
#print(key, sigma_samples[key][-1])
for key in var1:
cat_samples[key] = t_mean(cat_samples[key])
for key in var2:
mu_samples[key] = t_mean(mu_samples[key])
sigma_samples[key] = t_mean(sigma_samples[key])

updated_noise = {key : Categorical(cat_samples[key]) for key in var1}
updated_noise[‘latent’] = Normal(mu_samples[‘latent’], sigma_samples[‘latent’])
updated_noise[‘img’] = Normal(mu_samples[‘img’], sigma_samples[‘img’])
samples = {key:[] for key in var1}
img_samples = []
for _ in range(1000):
sh, sc, ori, px, py, _, im = scm.model(updated_noise)
samples[‘shape’].append(sh)
samples[‘scale’].append(sc)
samples[‘orien’].append(ori)
samples[‘posX’].append(px)
samples[‘posY’].append(py)
img_samples.append(im)
for key in var1:
samples[key] = t_mean(samples[key])
img_samples = t_mean(img_samples)
print(’=====================’)
print(‘SVI’)
print('The SVI was run with ’ + str(n_steps) + ’ steps. The Adam lr was ’ + str(lr) + ’ and the betas interval was ’ + str(betas) + ‘.’)
print(‘Graph of loss:’)
plt.figure()
plt.plot(range(len(losses)), losses)
print(‘The conditioned values was:’)
print('Shape: '+ str(datacond[‘shape’]))
print('Scale: '+ str(datacond[‘scale’]))
print('Orien: '+ str(datacond[‘orien’]))
print('PosX: '+ str(datacond[‘posX’]))
print('PosY: '+ str(datacond[‘posY’]))
print(‘The infered values was:’)
print('Shape: '+ str(samples[‘shape’]))
print('Scale: '+ str(samples[‘scale’]))
print('Orien: '+ str(samples[‘orien’]))
print('PosX: '+ str(samples[‘posX’]))
print('PosY: '+ str(samples[‘posY’]))
print(‘The image sampled was:’)
plt.figure()
plt.imshow(img_samples.detach().numpy().reshape(64, 64))
plt.show()

The output of some of the prob params are:

‘shape’: tensor([0.3013, 0.3178, 2.4624], grad_fn=),
‘scale’: tensor([4.7633, 2.4051, 0.5139, 0.1793, 0.0792, 0.2573],
grad_fn=),

Hi @GiovaniValdrighi, could you please format your code? It would be useful to have some words on what is your expected result and what seems to go wrong.