Simple inverse graphics example with SVI not working


#1

I’m new to Pyro and want to implement a simple inverse graphics example: estimating the vertices of a triangle drawn on a 32x32 black & white image.
I wrote a model that samples 3 vertices, renders them, and observes the result. I’m using SVI with an AutoMultivariateNormal guide (I also tried a convnet-based guide, and inserting noise into the rendering process) to do the inference, but all I’m getting is an inf loss.

(What does work: importance sampling plus a slightly stochastic rendering procedure with Normal noise. Doing the same with SVI, however, asymptotes at a huge ELBO loss and does not seem to learn anything.)

code:

from torch import zeros, ones, tensor, stack
from torch.nn import Sequential, Conv2d, ReLU, MaxPool2d, Module
from pyro import *
from pyro.distributions import *
from pyro.infer import SVI, Trace_ELBO, EmpiricalMarginal
from pyro.optim import Adam
from pyro.contrib.autoguide import AutoDiagonalNormal, AutoMultivariateNormal
import numpy as np
from matplotlib.pyplot import imshow, plot, hist, scatter
from tqdm import trange

def generate():
    # prior: 3 triangle vertices sampled uniformly in the unit square
    points = sample('points', Uniform(zeros(3,2),1))
    return points
    
def render(points):
    # rasterise the vertices (unit-square coordinates) onto a size x size grayscale image with PIL
    from PIL import Image, ImageDraw
    size = 32
    background = zeros((size, size))
    background = background.numpy()
    points = points.clone().detach().numpy()
    img = Image.fromarray(background)
    draw = ImageDraw.Draw(img)
    draw.polygon(points * size, fill=255)
    img = tensor(np.asarray(img)) / 255
    return img

def show(arg):
    if arg.shape[-1] == 2:
        arg = render(arg)
    imshow(arg.numpy(), cmap='gray')
    

def model(image=None):
    # generative model: sample vertices, render them, and observe the image via a Delta
    points = generate()
    img = render(points)
    return sample('image', Delta(img), obs=image)

triangles = generate()
image = render(triangles)
show(image)

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, Adam(dict(lr=1e-3)), Trace_ELBO(), num_samples=1000)
plot([svi.step(image) for _ in trange(1000)])
run = svi.run(image)

points = EmpiricalMarginal(run, 'points')
samples = stack([points.sample() for _ in range(1000)])
print(samples)
scatter(samples[..., 0].numpy(), samples[..., 1].numpy())

What am I missing here? How best to approach this problem?


#2

Hey there – I’ve never worked on inverse graphics problems before, but a couple quick thoughts:

  • Define your own guide that places a Bernoulli distribution over your points. The latent variables whose posterior you’re interested in (the pixel-wise activations, that is) are not Normally distributed; you probably want to model them as Bernoulli distributed instead.
  • I’ve seen posts on this forum saying that a Delta likelihood can lead to numerical instability, and I’ve experienced that myself. You may want to use something like sample('image', pyro.distributions.Bernoulli(probs=img), obs=image), with all values of 1 in img replaced with 0.99999 (see the sketch below).
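
A minimal sketch of that second suggestion, reusing your generate() and render() functions (the name bernoulli_model and the clamping values are just illustrative; on older Pyro versions use .independent(2) instead of .to_event(2)):

def bernoulli_model(image=None):
    points = generate()
    img = render(points)
    # clamp probabilities away from exactly 0 and 1 so the Bernoulli log-prob stays finite
    probs = img.clamp(1e-5, 1 - 1e-5)
    # declare the two pixel dimensions as event dims so the whole image is one observation
    return sample('image', Bernoulli(probs=probs).to_event(2), obs=image)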

#3
  • I’m actually interested in the coordinates of the triangle and not the image itself (pixel activations).
  • That’s a case of adding some noise to the rendering process, which, as I explained, simply yields a finite (but huge) ELBO loss that doesn’t seem to decrease (although it does work with importance sampling). A rough sketch of what I mean is below.
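
For reference, this is roughly the noisy variant I mean (a Normal likelihood around the rendered image; the 0.1 noise scale and the name noisy_model are arbitrary):

def noisy_model(image=None):
    points = generate()
    img = render(points)
    # observe the rendered image under Normal noise instead of a Delta
    return sample('image', Normal(img, 0.1).to_event(2), obs=image)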

#4

I’d also recommend starting with an AutoDelta guide (i.e. MAP inference), then moving to an AutoNormal guide (i.e. mean-field variational inference), and finally trying an AutoMultivariateNormal guide (correlated variational inference).
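
A minimal sketch of that progression, reusing the model, data, and optimizer from the original post (AutoDiagonalNormal stands in for the mean-field guide here; the step count is arbitrary, and AutoDelta is assumed to be importable from the same autoguide module):

import pyro
from pyro.contrib.autoguide import AutoDelta

# fit progressively richer guides to the same model and data
for Guide in (AutoDelta, AutoDiagonalNormal, AutoMultivariateNormal):
    pyro.clear_param_store()  # reset learned parameters between runs
    guide = Guide(model)
    svi = SVI(model, guide, Adam(dict(lr=1e-3)), Trace_ELBO())
    losses = [svi.step(image) for _ in range(1000)]
    print(Guide.__name__, losses[-1])  # final ELBO loss for each guide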