Hello, everyone
I’m kinda new with the Pyro package and have been working on a variational autoencoder with the DeepMind Dsprites dataset. I’m trying to do a SVI on my model and guide but I keep getting bugs, and now I got “Error while computing log_prob at site ‘N_img’: The value argument must be within the support” and don’t know how to fix it, the variable N_img have a Uniform distribution [0,1] in both model and guide.
Here is the model and guide:
#scm
class SCM():
def __init__(self, vae):
self.vae = vae
self.latents_names = ['color', 'shape', 'scale', 'orien', 'posX', 'posY']
self.latents_sizes = {'color' : 1, 'shape': 3, 'scale':6, 'orien': 40, 'posX': 32, 'posY':32}
temp = tensor([0.1])
dist = RelaxedOneHotCategorical
self.init_noise = {'color': dist(temp,probs = torch.tensor([1.])),
'shape': dist(temp,probs = torch.tensor([0.4, 0.4, 0.2])),
'scale': dist(temp,probs = torch.tensor([1/6]).repeat(6)),
'orien': dist(temp,probs = torch.tensor([1/40]).repeat(40)),
'posX': dist(temp,probs = torch.tensor([1/32]).repeat(32)),
'posY': dist(temp,probs = torch.tensor([1/32]).repeat(32)),
'latent': Normal(torch.zeros(200), torch.ones(200)),
'img': Uniform(torch.zeros(4096), torch.ones(4096))}
#functions for the model
def f_gumbel(N):
return N
def f_posX(N, scale):
if 31 <= int(scale.max(0)[1] + N.max(0)[1]):
return torch.nn.functional.one_hot(torch.tensor([31]), 32).to(torch.float32).reshape([32])
else:
return torch.nn.functional.one_hot(scale.max(0)[1] + N.max(0)[1], 32).to(torch.float32).reshape([32])
def f_latent(N_latent, color, shape, scale, orien, posX, posY):
ind = ind_from_att(color.max(0)[1], shape.max(0)[1], scale.max(0)[1], orien.max(0)[1], posX.max(0)[1], posY.max(0)[1])
label = torch.round(torch.cat([color, shape, scale, orien, posX, posY], -1))
mu, sigma = vae.encoder.forward(torch.tensor(imgs[ind]).reshape(4096).to(torch.float32), label)
return N_latent * sigma + mu
def f_image(N_img, latent, color, shape, scale, orien, posX, posY):
label = torch.round(torch.cat([color, shape, scale, orien, posX, posY], -1))
img_decode = vae.decoder.forward(latent, label)
return (N_img > img_decode).to(torch.float)
def model(noise = self.init_noise):
#Noise variables
N_color = pyro.sample('N_color', noise['color'])
N_shape = pyro.sample('N_shape', noise['shape'])
N_scale = pyro.sample('N_scale', noise['scale'])
N_orien = pyro.sample('N_orien', noise['orien'])
N_posX = pyro.sample('N_posX', noise['posX'])
N_posY = pyro.sample('N_posY', noise['posY'])
N_latent = pyro.sample('N_latent', noise['latent'].to_event(1))
N_img = pyro.sample('N_img', noise['img'].to_event(1))
#print(N_color, N_shape, N_scale, N_orien, N_posX, N_posY)
#variables
color = pyro.sample('color', Normal(f_gumbel(N_color), torch.tensor([0.01])).to_event(1))
shape = pyro.sample('shape', Normal(f_gumbel(N_shape), torch.tensor([0.01])).to_event(1))
scale = pyro.sample('scale', Normal(f_gumbel(N_scale), torch.tensor([0.01])).to_event(1))
orien = pyro.sample('orien', Normal(f_gumbel(N_orien), torch.tensor([0.01])).to_event(1))
posX = pyro.sample('posX', Normal(f_posX(N_posX, scale), torch.tensor([0.01])).to_event(1))
posY = pyro.sample('posY', Normal(f_gumbel(N_posY), torch.tensor([0.01])).to_event(1))
#variables
latent = pyro.sample('latent', Normal(f_latent(N_latent, color, shape, scale, orien, posX, posY), torch.tensor([0.01])).to_event(1))
img = pyro.sample('img', Normal(f_image(N_img, latent, color, shape, scale, orien, posX, posY), torch.tensor([0.01])).to_event(1))
return torch.round(torch.cat([color, shape, scale, orien, posX, posY], -1)), latent, img
def guide(noise = self.init_noise):
cat_val = {'color': 1, 'shape': 3, 'scale': 6, 'orien': 40, 'posX': 32, 'posY': 32}
temp = {key : pyro.param('temp_'+key, tensor([0.4]), constraint = constraints.positive) for key in cat_val}
probs = {key : pyro.param('prob_'+key, torch.ones(value)*0.5, constraint = constraints.interval(tensor(0.), tensor(1.))) for key, value in cat_val.items()}
mu_latent = pyro.param('mu_latent', torch.ones(200)*0.5)
sigma_latent = pyro.param('sigma_latent', torch.ones(200), constraint = constraints.positive)
low_img = pyro.param('low_img', torch.zeros(4096), constraint = constraints.positive)
high_img = pyro.param('high_img', torch.ones(4096), constraint = constraints.positive)
for key in cat_val.keys():
pyro.sample('N_'+key, RelaxedOneHotCategorical(temp[key], probs[key]))
pyro.sample('N_latent', Normal(mu_latent, sigma_latent).to_event(1))
print(pyro.sample('N_img', Uniform(low_img, high_img).to_event(1)))
return
def viz_model():
label, _, img = model()
label = label_from_dummy(label)
plt.imshow(img.detach().numpy().reshape(64, 64), cmap = 'Greys')
text = 'Color:0 Shape:'+str(int(label[1])) + ' Scale:' + str(int(label[2])) + ' Orien.:' + str(int(label[3])) + ' Pos.X:' + str(int(label[4])) + ' Pos.Y:' + str(int(label[5]))
plt.title(text)
plt.show()
self.model = model
self.guide = guide
self.viz_model = viz_model
scm = SCM(vae)
scm.viz_model()