Thanks! Here is my code:
# Side length in pixels of the square scene images (value taken from the
# rela_net script).
img_length = 75
# Number of pixels in one channel of an image.
img_size = img_length ** 2
class ProbModel(nn.Module):
    """Generative model of a scene containing one object per entry of ``self.colors``.

    Each object has three latent variables: an x centre ("idx"), a y centre
    ("idy") and a shape flag ("obj_type": 1 -> filled square, otherwise a
    filled circle). ``model`` defines p(x|z)p(z); ``guide`` defines a
    non-amortised variational distribution q(z|x).
    """

    def __init__(self, size=5, use_cuda=False):
        """
        Args:
            size: half-width in pixels of a square / radius of a circle.
            use_cuda: if True, move the module to the GPU; also forwarded to
                ``pyro.iarange`` in ``model`` and ``guide``.
        """
        super(ProbModel, self).__init__()
        # TODO: an encoder for an amortised guide is not wired in yet
        # (self.encoder = Encoder(); pyro.module("encoder", self.encoder)).
        #
        # NOTE(review): these triples are written directly into the image
        # channels by cv2, whose colour convention is BGR, so (0, 0, 255) is
        # "red" only if the image is interpreted as BGR. The target image in
        # this script is loaded with PIL as RGB -- confirm the channel orders
        # agree, otherwise the likelihood compares mismatched channels.
        self.colors = [
            (0, 0, 255),  # r
            (0, 255, 0),  # g
            # Currently disabled objects:
            # (255, 0, 0),  # b
            # (0, 156, 255),  # o
            # (128, 128, 128),  # k
            # (0, 255, 255),  # y
        ]
        self.color_names = ['red', 'green', 'blue', 'orange', 'black', 'yellow']
        self.size = size  # object half-width / radius in pixels
        if use_cuda:
            self.cuda()
        self.use_cuda = use_cuda

    def deterministic_image(self, idxs, idys, objtypes):
        """Render one scene as an (img_length, img_length, 3) float array in [0, 1].

        Args:
            idxs, idys: per-object centre coordinates, indexed in the same
                order as ``self.colors`` (Python ints or 0-d tensors).
            objtypes: per-object shape flags; a value equal to 1 draws a filled
                square, anything else a filled circle.

        Returns:
            numpy array of shape (img_length, img_length, 3) on a white
            background, scaled to [0, 1].
        """
        size = self.size
        img = np.ones((img_length, img_length, 3)) * 255  # white background
        for color_id, color in enumerate(self.colors):
            # cv2 point arguments must be plain Python ints; coming from pyro
            # samples these values are 0-d torch tensors.
            cx = int(idxs[color_id])
            cy = int(idys[color_id])
            if objtypes[color_id] == 1:
                cv2.rectangle(img, (cx - size, cy - size), (cx + size, cy + size), color, -1)
            else:
                cv2.circle(img, (cx, cy), size, color, -1)
        return img / 255.

    # define the model p(x|z)p(z)
    def model(self, x, observe=True):
        """Sample latent object placements, render them, and optionally score x.

        Args:
            x: image batch; its first dimension is the batch size and
                ``x.reshape(batch, -1)`` is used as the observation.
                Presumably (batch, 3, img_length, img_length) -- confirm.
            observe: if False, the "obs" likelihood site is skipped, so the
                trace contains only prior samples.

        Returns:
            numpy array (batch, 3, img_length, img_length) of rendered images.
        """
        size = self.size
        n_objects = len(self.colors)
        batch = x.shape[0]
        # Centres are drawn from [size, img_length - size - 1] so that every
        # object (extending +/- size around its centre) stays inside the image.
        n_positions = img_length - 2 * size
        with pyro.iarange("batch", batch, use_cuda=self.use_cuda):
            pixel_probs = x.new_ones(torch.Size((batch, n_objects, n_positions)))
            obj_probs = x.new_ones(torch.Size((batch, n_objects))) * 0.5
            idxs = size + pyro.sample("idx", dist.Categorical(pixel_probs).independent(1))
            idys = size + pyro.sample("idy", dist.Categorical(pixel_probs).independent(1))
            objtypes = pyro.sample("obj_type", dist.Bernoulli(obj_probs).independent(1))
            img = x.new_zeros(torch.Size((batch, img_length, img_length, 3)))
            for i in range(batch):
                img_i = self.deterministic_image(idxs[i], idys[i], objtypes[i])
                img[i] = torch.from_numpy(img_i)
            # (batch, H, W, C) in [0, 1] -> (batch, C, H, W) to match x's layout.
            img = img.permute(0, 3, 1, 2)
            img_flat = img.contiguous().view(batch, -1)
            if observe:
                pyro.sample("obs", dist.Normal(img_flat, 0.1).independent(1),
                            obs=x.reshape(batch, -1))
        return img.detach().cpu().numpy()

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        """Variational distribution over the latent sites sampled in ``model``.

        The x and y positions and the shape flag each get their own learnable
        parameter, so q can fit different marginals for each of them.

        Returns:
            (idxs, idys, objtypes) sampled from q. The positions are the raw
            category indices, without the ``size`` offset added in ``model``.
        """
        from torch.distributions import constraints

        size = self.size
        n_objects = len(self.colors)
        batch = x.shape[0]
        n_positions = img_length - 2 * size
        with pyro.iarange("batch", batch, use_cuda=self.use_cuda):
            # Categorical normalises its probs, so positivity is the only
            # constraint the position weights need.
            idx_probs = pyro.param(
                "idx_param",
                x.new_ones(torch.Size((batch, n_objects, n_positions))),
                constraint=constraints.positive)
            idy_probs = pyro.param(
                "idy_param",
                x.new_ones(torch.Size((batch, n_objects, n_positions))),
                constraint=constraints.positive)
            obj_probs = pyro.param(
                "obj_type_param",
                x.new_ones(torch.Size((batch, n_objects))) * 0.5,
                constraint=constraints.unit_interval)
            idxs = pyro.sample("idx", dist.Categorical(idx_probs).independent(1))
            idys = pyro.sample("idy", dist.Categorical(idy_probs).independent(1))
            objtypes = pyro.sample("obj_type", dist.Bernoulli(obj_probs).independent(1))
        return idxs, idys, objtypes
# PIL image -> float tensor in (C, H, W) layout.
loaderTensor = transforms.Compose([
    transforms.ToTensor()])


def load_image(filename):
    """Load an image file as a batched RGB float tensor of shape (1, C, H, W).

    The file handle is closed before returning; ``convert('RGB')`` produces an
    independent in-memory copy, so the result does not depend on the open file.
    """
    with Image.open(filename) as image:
        rgb_image = image.convert('RGB')
    return loaderTensor(rgb_image).unsqueeze(0).float()
USE_CUDA = False

target_img = load_image('../data/target.png')
probmodel = ProbModel(use_cuda=USE_CUDA)
# To condition on a synthetic scene instead of the target file, draw one with
# probmodel.model(target_img, observe=False) and wrap the returned numpy array
# with torch.from_numpy.
condition_img = target_img
if USE_CUDA:
    target_img = target_img.cuda()
    condition_img = condition_img.cuda()

num_samples = 60000
start_time = time.time()
# NOTE: the model is run with observe=False, so no "obs" site is conditioned on
# and these importance samples are draws from the prior (as the plot title says).
posterior = pyro.infer.Importance(probmodel.model, num_samples=num_samples).run(condition_img, False)
marginal_loc = pyro.infer.EmpiricalMarginal(posterior, sites=['idx', 'idy'])
marginal_obj = pyro.infer.EmpiricalMarginal(posterior, sites=['obj_type'])
print('Time taken %s' % (time.time() - start_time))

# Panel 1 shows the conditioning image; panels 2-4 show scenes rendered from
# samples of the empirical marginals.
fig = plt.figure()
for rr in range(4):
    ax = fig.add_subplot(2, 2, rr + 1)
    if rr == 0:
        # (1, C, H, W) tensor -> (H, W, C) on the CPU so matplotlib can draw it,
        # also when USE_CUDA moved it to the GPU.
        ax.imshow(condition_img.squeeze().cpu().permute(1, 2, 0))
    else:
        [idx], [idy] = marginal_loc()
        [[obj_type]] = marginal_obj()
        inf_img = probmodel.deterministic_image(idx, idy, obj_type)
        ax.imshow(inf_img)
plt.suptitle('Samples from no-conditioning empirical marginal - ' + str(num_samples))
fig.set_figheight(7)
plt.show()