Simple inverse graphics example with SVI not working

I’m new to Pyro and want to implement a simple inverse-graphics example: estimating the vertices of a triangle drawn on a 32x32 black-and-white image.
I wrote a model that samples 3 vertices, renders them, and observes the result. I’m using SVI with an AutoMultivariateNormal guide to do the inference (I also tried a convnet-based guide, and adding noise to the rendering process), but all I get is an inf loss.

(What does work: importance sampling combined with a slightly stochastic rendering procedure that adds normal noise. Doing the same with SVI, however, plateaus at a huge ELBO loss and does not seem to learn anything.)

code:

from torch import zeros, ones, tensor
from torch.nn import Sequential, Conv2d, ReLU, MaxPool2d, Module
from pyro import *
from pyro.distributions import *
from pyro.infer import SVI, Trace_ELBO, EmpiricalMarginal
from pyro.optim import Adam
from pyro.contrib.autoguide import AutoDiagonalNormal, AutoMultivariateNormal
import numpy as np
from matplotlib.pyplot import imshow, plot, hist, scatter
from tqdm import trange

def generate():
    points = sample('points', Uniform(zeros(3,2),1))
    return points
    
def render(points):
    from PIL import Image, ImageDraw
    size = 32
    background = zeros((size, size))
    background = background.numpy()
    points = points.clone().detach().numpy()
    img = Image.fromarray(background)
    draw = ImageDraw.Draw(img)
    draw.polygon([tuple(p) for p in (points * size).tolist()], fill=255)  # PIL expects a list of (x, y) tuples
    img = tensor(np.asarray(img)) / 255
    return img

def show(arg):
    if arg.shape[-1] == 2:
        arg = render(arg)
    imshow(arg.numpy(), cmap='gray')
    

def model(image=None):
    points = generate()
    img = render(points)
    return sample('image', Delta(img), obs=image)

triangles = generate()
image = render(triangles)
show(image)

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, Adam(dict(lr=1e-3)), Trace_ELBO(), num_samples=1000)
plot([svi.step(image) for _ in trange(1000)])
run = svi.run(image)

points = EmpiricalMarginal(run, 'points')
samples = np.stack([points.sample().detach().numpy() for _ in range(1000)])  # shape (1000, 3, 2)
print(samples)
scatter(samples[..., 0].ravel(), samples[..., 1].ravel())  # x vs. y of every sampled vertex

What am I missing here? How best to approach this problem?
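A likely reason for the inf loss: an observed Delta site scores log(1) = 0 when the value matches the rendered image exactly and log(0) = -inf otherwise, so a single wrong pixel makes the log-likelihood -inf and the ELBO loss +inf. The sketch below imitates that scoring in plain PyTorch on hypothetical 4x4 images (the helper `delta_log_likelihood` is illustrative, not a Pyro function) instead of calling Pyro's `Delta` directly.

```python
import torch

def delta_log_likelihood(rendered: torch.Tensor, observed: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper imitating a point mass at `rendered`:
    log(1) = 0 for every matching pixel, log(0) = -inf for any mismatch."""
    return torch.log((rendered == observed).float()).sum()

# Hypothetical 4x4 binary images standing in for the 32x32 renders.
observed = torch.zeros(4, 4)
observed[1:3, 1:3] = 1.0

exact = observed.clone()      # a guide sample that reproduces the image exactly
off_by_one = observed.clone()
off_by_one[0, 0] = 1.0        # a guide sample whose render differs in one pixel

print(delta_log_likelihood(exact, observed))       # tensor(0.)
print(delta_log_likelihood(off_by_one, observed))  # tensor(-inf)
```

Since almost any triangle the guide proposes will rasterize to a slightly different pixel pattern than the observed one, in practice most samples would score -inf under this likelihood.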

Hey there – I’ve never worked on inverse graphics problems before, but a couple quick thoughts:

  • Define your own guide which places a Bernoulli prior over your points. The latent variables you’re interested in understanding the posterior of (pixel-wise activations, that is) are not Normally distributed. Instead, you probably want to model them as Bernoulli distributed.
  • I’ve seen some posts on this forum saying that a Delta distribution can lead to numerical instability, and I’ve experienced that myself. You may want to use something like sample('image', pyro.distributions.Bernoulli(probs=img), obs=image), with every value of 1 in img replaced by 0.99999.
  • I’m actually interested in the coordinates of the triangle and not the image itself (pixel activations).
  • That amounts to adding some noise to the rendering process, which, as I explained, simply yields a finite (but huge) ELBO loss that doesn’t seem to decrease (although it works with importance sampling).
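The Bernoulli suggestion can be sanity-checked without running SVI. The sketch below, in plain PyTorch with hypothetical 4x4 images, scores the observed image under independent per-pixel Bernoullis whose probabilities come from the render clamped away from 0 and 1 (`bernoulli_log_likelihood` is an illustrative helper; `eps=1e-5` is an assumption matching the 0.99999 above). A one-pixel mismatch now costs a large but finite penalty instead of -inf.

```python
import torch
from torch.distributions import Bernoulli

def bernoulli_log_likelihood(rendered, observed, eps=1e-5):
    # Keep every pixel probability strictly inside (0, 1); eps=1e-5 corresponds to
    # replacing 1 with 0.99999 (and, symmetrically, 0 with 0.00001).
    probs = rendered.clamp(eps, 1 - eps)
    # Independent Bernoulli per pixel, summed over the whole image.
    return Bernoulli(probs=probs).log_prob(observed).sum()

# Hypothetical 4x4 binary images standing in for the 32x32 renders.
observed = torch.zeros(4, 4)
observed[1:3, 1:3] = 1.0
exact = observed.clone()
off_by_one = observed.clone()
off_by_one[0, 0] = 1.0  # render differs from the observation in one pixel

ll_exact = bernoulli_log_likelihood(exact, observed)
ll_off = bernoulli_log_likelihood(off_by_one, observed)
print(ll_exact, ll_off)  # both finite; ll_off is lower by about 11.5 (= -log(1e-5))
```

This only makes the loss finite. In the posted model, `render` calls `.detach()` and rasterizes with PIL, so whichever likelihood is used, its value carries no gradient back to `points`; with a reparameterized autoguide, Trace_ELBO then gets gradient only from the prior and guide terms. That would be consistent with the ELBO not decreasing, and with importance sampling (which needs no gradients) working.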

I’d also recommend starting with an AutoDelta guide (i.e. MAP inference), then move to an AutoNormal guide (i.e. mean field variational inference), and finally try an AutoMultivariateNormal (correlated variational inference).