I’m new to pyro and want to implement a simple inverse graphics example involving estimating the vertices of a triangle drawn on a 32x32 black & white image.
I invented a model that samples 3 vertices, renders them and observes the results. I’m using SVI
with a AutoMultivariateNormal
guide (I also tried a convnet based guide, and inserting noise in the rendering process) to to the inference, but all I’m getting is inf
loss.
(What does work: using importance sampling + a slightly stochastic rendering procedure with normal noise, however, doing the same with SVI asymptoetes on a huge ELBO loss and seems that it does not learn anything)
code:
from torch import zeros, ones, tensor
from torch.nn import Sequential, Conv2d, ReLU, MaxPool2d, Module
from pyro import *
from pyro.distributions import *
from pyro.infer import SVI, Trace_ELBO, EmpiricalMarginal
from pyro.optim import Adam
from pyro.contrib.autoguide import AutoDiagonalNormal, AutoMultivariateNormal
import numpy as np
from matplotlib.pyplot import imshow, plot, hist, scatter
from tqdm import trange
def generate():
points = sample('points', Uniform(zeros(3,2),1))
return points
def render(points):
from PIL import Image, ImageDraw
size = 32
background = zeros((size, size))
background = background.numpy()
points = points.clone().detach().numpy()
img = Image.fromarray(background)
draw = ImageDraw.Draw(img)
draw.polygon(points * size, fill=255)
img = tensor(np.asarray(img)) / 255
return img
def show(arg):
if arg.shape[-1] == 2:
arg = render(arg)
imshow(arg.numpy(), cmap='gray')
def model(image=None):
points = generate()
img = render(points)
return sample('image', Delta(img), obs=image)
triangles = generate()
image = render(triangles)
show(image)
guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, Adam(dict(lr=1e-3)), Trace_ELBO(), num_samples=1000)
plot([svi.step(image) for _ in trange(1000)])
run = svi.run(image)
points = EmpiricalMarginal(posterior, 'points')
samples = [points.sample() for _ in range(1000)]
print(samples)
scatter(samples[:,0],samples[:,1])
What am I missing here? How best to approach this problem?