Hi, I'm new to Pyro and Bayesian inference. I've been working on an SVI setup in which I want to infer noise variables that follow Categorical distributions, conditioned on observed values. However, when the algorithm runs, it doesn't converge to values close to the expected parameters.
Here is the code:
def t_mean(tens):
    """Return the element-wise mean of a sequence of equally-shaped tensors.

    The tensors are stacked along a new leading dimension, and that dimension
    is then averaged away. The result therefore has the shape of one input.
    """
    stacked = torch.stack(tens)
    return stacked.mean(dim=0)
def dummy(shape, scale, orien, posX, posY):
    """Build the conditioning label vector for the VAE.

    The result is a leading constant 1 followed by one-hot encodings of
    shape (3 classes), scale (6), orien (40), posX (32) and posY (32),
    concatenated into a single float32 vector of length 114.
    """
    one_hot = torch.nn.functional.one_hot
    widths = ((shape, 3), (scale, 6), (orien, 40), (posX, 32), (posY, 32))
    pieces = [tensor([1])]
    for value, n_classes in widths:
        pieces.append(one_hot(tensor(value), n_classes))
    return torch.cat(pieces, -1).to(torch.float32)
class SCM_test():
    """Structural causal model over image factors (presumably dSprites-style).

    The factors are shape, scale, orientation, posX and posY, plus a VAE
    latent code and a binarised 64x64 image.

    After construction, the instance exposes:
      * ``init_noise``: the default distributions of the exogenous noise.
      * ``model`` / ``guide``: closures with matching signatures, usable with
        ``pyro.infer.SVI``. The guide registers the Pyro params ``prob_*``,
        ``mu_*`` and ``sigma_*``.
    """

    def __init__(self, vae):
        """Build the noise priors and the model/guide closures.

        Args:
            vae: object whose ``encoder.forward(x, label)`` returns
                ``(mu, sigma)`` and whose ``decoder.forward(latent, label)``
                returns a decoded image; labels are built by ``dummy``.
        """
        self.vae = vae
        self.latents_names = ['shape', 'scale', 'orien', 'posX', 'posY']
        # Category counts must match the one-hot widths used by ``dummy``
        # (3, 6, 40, 32, 32). posY used to have 40 categories, which allowed
        # N_posY >= 32 and made one_hot fail.
        self.init_noise = {
            'shape': Categorical(tensor([.4, .4, .2])),
            # Categorical normalises these unnormalised weights internally.
            'scale': Categorical(tensor([1., 2., 2., 2., 2., 1.])),
            'orien': Categorical(tensor(1/40).repeat(40)),
            'posY': Categorical(tensor(1/32).repeat(32)),
            'posX': Categorical(tensor(1/32).repeat(32)),
            'latent': Normal(torch.zeros(200), torch.ones(200)),
            'img': Uniform(torch.zeros(4096), torch.ones(4096)),
        }

        def _round_labels(*values):
            """Round each (noisy, float tensor) factor to a Python int."""
            return [int(torch.round(v)) for v in values]

        def f_latent(N_latent, shape, scale, orien, posX, posY):
            """Structural equation for the VAE latent.

            The factors are rounded to ints. The matching image is fetched via
            the module-level ``ind_from_att`` / ``imgs`` (defined elsewhere)
            and encoded, and N_latent reparameterises the encoder output.
            """
            shape, scale, orien, posX, posY = _round_labels(
                shape, scale, orien, posX, posY)
            ind = ind_from_att(0, shape, scale, orien, posX, posY)
            label = dummy(shape, scale, orien, posX, posY)
            x = torch.tensor(imgs[ind]).reshape(4096).to(torch.float32)
            mu, sigma = vae.encoder.forward(x, label)
            return N_latent * sigma + mu

        def f_image(N_img, latent, color, shape, scale, orien, posX, posY):
            """Structural equation for the image: decode, then threshold against N_img.

            ``color`` is accepted for interface symmetry but is not used.
            """
            shape, scale, orien, posX, posY = _round_labels(
                shape, scale, orien, posX, posY)
            label = dummy(shape, scale, orien, posX, posY)
            img_decode = vae.decoder.forward(latent, label)
            # A hard comparison yields a tensor with no gradient, so nothing
            # flows back to N_img (or its guide params) through this path.
            return (N_img > img_decode).to(torch.float32)

        def f_cat(N):
            """Identity structural equation: a categorical noise value as float32."""
            return N.to(torch.float32)

        def f_posX(N, scale):
            """posX = N + scale, clipped from above at 31, returned as float32."""
            # Equivalent to the old ``if 31 <= N+scale`` branch, but avoids the
            # torch.tensor(tensor) copy-construct warning.
            return torch.as_tensor(N + scale, dtype=torch.float32).clamp(max=31.)

        def model(noise=self.init_noise):
            """Generative SCM: sample the exogenous noise, then the endogenous variables."""
            # noise variables
            N_shape = pyro.sample('N_shape', noise['shape'])
            N_scale = pyro.sample('N_scale', noise['scale'])
            N_orien = pyro.sample('N_orien', noise['orien'])
            N_posX = pyro.sample('N_posX', noise['posX'])
            N_posY = pyro.sample('N_posY', noise['posY'])
            N_latent = pyro.sample('N_latent', noise['latent'].to_event(1))
            N_img = pyro.sample('N_img', noise['img'].to_event(1))
            # endogenous variables (narrow Normals so they can be conditioned on)
            shape = pyro.sample('shape', Normal(f_cat(N_shape), tensor(0.01)))
            scale = pyro.sample('scale', Normal(f_cat(N_scale), tensor(0.01)))
            orien = pyro.sample('orien', Normal(f_cat(N_orien), tensor(0.01)))
            posX = pyro.sample('posX', Normal(f_posX(N_posX, scale), tensor(0.01)))
            posY = pyro.sample('posY', Normal(f_cat(N_posY), tensor(0.01)))
            latent = pyro.sample('latent', Normal(
                f_latent(N_latent, shape, scale, orien, posX, posY),
                tensor(0.01)).to_event(1))
            img = pyro.sample('img', Normal(
                f_image(N_img, latent, tensor([1.]), shape, scale, orien, posX, posY),
                tensor(0.01)).to_event(1))
            return shape, scale, orien, posX, posY, latent, img

        def guide(noise=self.init_noise):
            """Variational guide over the exogenous noise.

            ``noise`` is unused; it is present so the signature matches ``model``.
            """
            # Probabilities live on the simplex, so the learned params are
            # already normalised (``constraints.positive`` left them unnormalised).
            prob_shape = pyro.param('prob_shape', tensor(1/3).repeat(3), constraint=constraints.simplex)
            prob_scale = pyro.param('prob_scale', tensor(1/6).repeat(6), constraint=constraints.simplex)
            prob_orien = pyro.param('prob_orien', tensor(1/40).repeat(40), constraint=constraints.simplex)
            prob_posX = pyro.param('prob_posX', tensor(1/32).repeat(32), constraint=constraints.simplex)
            prob_posY = pyro.param('prob_posY', tensor(1/32).repeat(32), constraint=constraints.simplex)
            mu = {}
            sigma = {}
            mu['latent'] = pyro.param('mu_latent', torch.ones(200)*0.01)
            sigma['latent'] = pyro.param('sigma_latent', torch.ones(200), constraint=constraints.positive)
            # For N_img, mu/sigma are the Uniform's low/high bounds. Both
            # initial values must lie strictly inside the interval: 1.0 was
            # outside it, which made the unconstrained value NaN.
            # NOTE(review): nothing here enforces low < high during training.
            mu['img'] = pyro.param('mu_img', torch.ones(4096)*.01, constraint=constraints.interval(0.001, 0.99))
            sigma['img'] = pyro.param('sigma_img', torch.ones(4096)*.98, constraint=constraints.interval(0.001, 0.99))
            # noise variables. The categorical sites are not reparameterisable,
            # so Trace_ELBO uses high-variance score-function gradients for them.
            N_shape = pyro.sample('N_shape', Categorical(prob_shape))
            N_scale = pyro.sample('N_scale', Categorical(prob_scale))
            N_orien = pyro.sample('N_orien', Categorical(prob_orien))
            N_posX = pyro.sample('N_posX', Categorical(prob_posX))
            N_posY = pyro.sample('N_posY', Categorical(prob_posY))
            N_latent = pyro.sample('N_latent', Normal(mu['latent'], sigma['latent']).to_event(1))
            N_img = pyro.sample('N_img', Uniform(mu['img'], sigma['img']).to_event(1))
            # Auxiliary copies of the endogenous variables, computed with the
            # same structural equations as the model (posX uses f_posX).
            shape = pyro.sample('shape', Normal(f_cat(N_shape), tensor(0.01)), infer={'is_auxiliary': True})
            scale = pyro.sample('scale', Normal(f_cat(N_scale), tensor(0.01)), infer={'is_auxiliary': True})
            orien = pyro.sample('orien', Normal(f_cat(N_orien), tensor(0.01)), infer={'is_auxiliary': True})
            posX = pyro.sample('posX', Normal(f_posX(N_posX, scale), tensor(0.01)), infer={'is_auxiliary': True})
            posY = pyro.sample('posY', Normal(f_cat(N_posY), tensor(0.01)), infer={'is_auxiliary': True})
            latent = pyro.sample('latent', Normal(
                f_latent(N_latent, shape, scale, orien, posX, posY),
                tensor(0.01)).to_event(1))
            pyro.sample('img', Normal(
                f_image(N_img, latent, tensor([1.]), shape, scale, orien, posX, posY),
                tensor(0.01)).to_event(1))

        self.model = model
        self.guide = guide
def svi_calculate(scm, datacond, lr, betas, n_steps, num_particles=1, burn_in=0):
    """Fit ``scm.guide`` to ``scm.model`` conditioned on ``datacond``, then report.

    Steps:
      1. Run ``n_steps`` of SVI with Adam.
      2. Average the learned guide params over the steps after ``burn_in``.
      3. Draw 1000 samples from the model under those averaged noise
         distributions.
      4. Print and plot the loss curve, the conditioned vs. inferred factor
         means, and the mean sampled image.

    Args:
        scm: object providing ``model``, ``guide`` and ``init_noise``
            (e.g. an ``SCM_test``).
        datacond: observations passed to ``pyro.condition``. The keys 'shape',
            'scale', 'orien', 'posX' and 'posY' are read for printing.
        lr: Adam learning rate.
        betas: Adam betas.
        n_steps: number of SVI steps.
        num_particles: ELBO particles per step. Values > 1 reduce the variance
            of the score-function gradients for the discrete noise sites.
        burn_in: number of initial steps excluded from parameter averaging.
            The default of 0 averages over every step.

    Raises:
        ValueError: if ``burn_in`` is negative or leaves no steps to average.
    """
    if burn_in < 0 or (burn_in > 0 and burn_in >= n_steps):
        raise ValueError('burn_in must be in [0, n_steps)')

    var1 = ['shape', 'scale', 'orien', 'posX', 'posY']
    var2 = ['latent', 'img']

    # setting the SVI attributes
    condModel = pyro.condition(scm.model, data=datacond)
    pyro.clear_param_store()
    optimizer = pyro.optim.Adam({'lr': lr, 'betas': betas})
    elbo = pyro.infer.Trace_ELBO(num_particles=num_particles)
    svi = pyro.infer.SVI(condModel, scm.guide, optimizer, elbo)

    # training loop
    cat_samples = {key: [] for key in var1}
    mu_samples = {key: [] for key in var2}
    sigma_samples = {key: [] for key in var2}
    losses = []
    for i in range(n_steps):
        losses.append(svi.step(scm.init_noise))
        if i < burn_in:
            continue
        # Snapshot with detach().clone(). pyro.param returns the very tensor
        # Adam updates in place for unconstrained params (e.g. mu_latent), so
        # storing references would make every entry equal the final value.
        # Detaching also avoids keeping one autograd graph per step alive.
        for key in var1:
            cat_samples[key].append(pyro.param('prob_' + key).detach().clone())
        for key in var2:
            mu_samples[key].append(pyro.param('mu_' + key).detach().clone())
            sigma_samples[key].append(pyro.param('sigma_' + key).detach().clone())

    cat_mean = {key: t_mean(cat_samples[key]) for key in var1}
    mu_mean = {key: t_mean(mu_samples[key]) for key in var2}
    sigma_mean = {key: t_mean(sigma_samples[key]) for key in var2}

    updated_noise = {key: Categorical(cat_mean[key]) for key in var1}
    updated_noise['latent'] = Normal(mu_mean['latent'], sigma_mean['latent'])
    # The guide models N_img as Uniform(low=mu_img, high=sigma_img). Reuse the
    # same family here: a Normal built from those bounds is a different
    # distribution.
    updated_noise['img'] = Uniform(mu_mean['img'], sigma_mean['img'])

    # Posterior predictive sampling; no gradients are needed here.
    samples = {key: [] for key in var1}
    img_samples = []
    with torch.no_grad():
        for _ in range(1000):
            sh, sc, ori, px, py, _latent, im = scm.model(updated_noise)
            for key, value in zip(var1, (sh, sc, ori, px, py)):
                samples[key].append(value)
            img_samples.append(im)
    samples = {key: t_mean(values) for key, values in samples.items()}
    img_mean = t_mean(img_samples)

    labels = [('Shape', 'shape'), ('Scale', 'scale'), ('Orien', 'orien'),
              ('PosX', 'posX'), ('PosY', 'posY')]

    print('=====================')
    print('SVI')
    print('The SVI was run with ' + str(n_steps) + ' steps. The Adam lr was '
          + str(lr) + ' and the betas interval was ' + str(betas) + '.')
    print('Graph of loss:')
    plt.figure()
    plt.plot(range(len(losses)), losses)
    print('The conditioned values was:')
    for label, key in labels:
        print(label + ': ' + str(datacond[key]))
    print('The infered values was:')
    for label, key in labels:
        print(label + ': ' + str(samples[key]))
    print('The image sampled was:')
    plt.figure()
    plt.imshow(img_mean.detach().numpy().reshape(64, 64))
    plt.show()
The outputs of some of the prob params are:
'shape': tensor([0.3013, 0.3178, 2.4624], grad_fn=<...>),
'scale': tensor([4.7633, 2.4051, 0.5139, 0.1793, 0.0792, 0.2573],
grad_fn=<...>),