I’m new to Pyro and want to implement a simple inverse graphics example: estimating the vertices of a triangle drawn on a 32x32 black-and-white image.
I wrote a model that samples 3 vertices, renders them, and observes the result. I’m using SVI with an AutoMultivariateNormal guide (I also tried a convnet-based guide, and inserting noise into the rendering process) to do the inference, but all I’m getting is
(What does work: importance sampling combined with a slightly stochastic rendering procedure with normal noise. Doing the same with SVI, however, plateaus at a huge ELBO loss and does not seem to learn anything.)
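For reference, the importance-sampling variant mentioned above can be sketched roughly as follows. This is a stdlib-only stand-in rather than the actual Pyro/PIL code: the rasterizer, the function names (`render`, `log_likelihood`, `importance_sample`), the noise scale `sigma=0.5`, and the particle count of 500 are all assumptions chosen for illustration. It draws vertices from the uniform prior and weights each draw by a Gaussian pixel-noise likelihood against the observed image.

```python
import math
import random

SIZE = 32  # image side length, as in the post


def render(points, size=SIZE):
    """Rasterize a filled triangle into a flat list of 0.0/1.0 pixels.

    points: three (x, y) vertices in the unit square.
    """
    (x0, y0), (x1, y1), (x2, y2) = points

    def edge(ax, ay, bx, by, px, py):
        # Signed area test: which side of edge a->b the point p lies on.
        return (bx - ax) * (py - ay) - (by - ay) * (px - ax)

    img = []
    for row in range(size):
        py = (row + 0.5) / size  # pixel centre
        for col in range(size):
            px = (col + 0.5) / size
            e0 = edge(x0, y0, x1, y1, px, py)
            e1 = edge(x1, y1, x2, y2, px, py)
            e2 = edge(x2, y2, x0, y0, px, py)
            inside = (e0 >= 0 and e1 >= 0 and e2 >= 0) or (
                e0 <= 0 and e1 <= 0 and e2 <= 0
            )
            img.append(1.0 if inside else 0.0)
    return img


def log_likelihood(rendered, observed, sigma):
    """Gaussian pixel-noise log-likelihood, up to an additive constant."""
    sq = sum((r - o) ** 2 for r, o in zip(rendered, observed))
    return -0.5 * sq / sigma ** 2


def sample_prior(rng):
    """Three vertices drawn uniformly from the unit square."""
    return [(rng.random(), rng.random()) for _ in range(3)]


def importance_sample(observed, n, sigma, rng):
    """Use the prior as proposal; return particles and normalized weights."""
    particles = [sample_prior(rng) for _ in range(n)]
    log_w = [log_likelihood(render(p), observed, sigma) for p in particles]
    top = max(log_w)  # subtract the max for numerical stability
    w = [math.exp(lw - top) for lw in log_w]
    total = sum(w)
    return particles, [wi / total for wi in w]


rng = random.Random(0)
true_points = sample_prior(rng)
observed = render(true_points)

particles, weights = importance_sample(observed, n=500, sigma=0.5, rng=rng)
best = particles[max(range(len(weights)), key=weights.__getitem__)]
ess = 1.0 / sum(w * w for w in weights)  # effective sample size

print("true vertices:", true_points)
print("best particle:", best)
print("effective sample size:", round(ess, 2))
```

The noise scale controls how sharply the weights concentrate: a small `sigma` makes a single mismatched pixel very costly, so the effective sample size tends to collapse toward 1, while a larger `sigma` spreads the weight over more particles.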
from torch import zeros, ones, tensor, stack
from torch.nn import Sequential, Conv2d, ReLU, MaxPool2d, Module
from pyro import *
from pyro.distributions import *
from pyro.infer import SVI, Trace_ELBO, EmpiricalMarginal
from pyro.optim import Adam
from pyro.contrib.autoguide import AutoDiagonalNormal, AutoMultivariateNormal
import numpy as np
from matplotlib.pyplot import imshow, plot, hist, scatter
from tqdm import trange

def generate():
    # Prior: three (x, y) vertices, uniform on the unit square
    points = sample('points', Uniform(zeros(3,2),1))
    return points

def render(points):
    # Rasterize the triangle with PIL; returns a 32x32 tensor in [0, 1]
    from PIL import Image, ImageDraw
    size = 32
    background = zeros((size, size))
    background = background.numpy()
    points = points.clone().detach().numpy()
    img = Image.fromarray(background)
    draw = ImageDraw.Draw(img)
    draw.polygon(points * size, fill=255)
    img = tensor(np.asarray(img)) / 255
    return img

def show(arg):
    # Accepts either vertices (last dim 2) or an already rendered image
    if arg.shape[-1] == 2:
        arg = render(arg)
    imshow(arg.numpy(), cmap='gray')

def model(image=None):
    points = generate()
    img = render(points)
    # Hard (Delta) observation of the rendered image
    return sample('image', Delta(img), obs=image)

# Ground-truth triangle and its rendering
triangles = generate()
image = render(triangles)
show(image)

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, Adam(dict(lr=1e-3)), Trace_ELBO(), num_samples=1000)
plot([svi.step(image) for _ in trange(1000)])

# svi.run returns the posterior object (previously bound to `run` but read as `posterior`)
posterior = svi.run(image)
points = EmpiricalMarginal(posterior, 'points')
# Stack into one tensor so the samples can be sliced
samples = stack([points.sample() for _ in range(1000)])
print(samples)
scatter(samples[:,0],samples[:,1])
What am I missing here? How best to approach this problem?