Implementation: Generative Probabilistic Graphics Programs

Hello

As an exercise, I would like to implement a Generative Probabilistic Graphics Program to read
simple captcha graphics, as described here:

“Approximate Bayesian Image Interpretation using Generative Probabilistic Graphics Programs”
Vikash K. Mansinghka, Tejas D. Kulkarni, Yura N. Perov, Joshua B. Tenenbaum

The paper includes some pseudocode for Church.
Based on this, I would like to implement it using NumPyro.

My plan so far is to start from NumPyro’s VAE tutorial.

If I understand correctly, the key difference between a VAE and the model proposed in the paper is that in the latter we already have a top-down model of how we think a given captcha is constructed. In particular, this is expressed in the latent variables (e.g. dedicated variables for glyph identity and size), while in a VAE the latent space is not necessarily linked directly to such “real-world” variables but rather learnt implicitly. Based on this idea, I thought I could replace the encoder with a handcrafted model including latent variables for glyph identities, sizes, rotations, etc.
Python’s Pillow library could serve as a simple rendering pipeline.
A likelihood model could then assess the probability of the underlying latent variables given some observed captcha image. With priors for the latent variables and the likelihood, we could use variational inference to approximate the posterior, and thus the most likely glyph identities of the observed captcha.

As I am rather new to probabilistic programming and NumPyro, I would really appreciate some help here.

My current questions are:

  1. Does my plan make sense so far? :slight_smile:

  2. What would be a good likelihood model for comparing the rendered image with the input image?
    The paper mentions it uses a pixel-wise comparison model.

  3. Might Pyro, rather than NumPyro, be a better framework for this?

  4. Is NumPyro’s VAE tutorial a good starting point for my endeavor?

  5. Are there additional tutorials that may help me here?

my two cents are that you may be trying to learn too many things simultaneously (namely probabilistic programming and deep generative modeling). for example, you probably want a differentiable renderer, which Pillow presumably is not. smaller learning steps may be a good idea.

Thanks for your feedback.
There are surely many things to learn here!
However, deep generative modelling and differentiable rendering might not be necessary for the paper I had in mind.

I actually tried to draft a very simple model using NumPyro, below.
So far it fails with JAX errors.
But I was hoping to get some input here that may help me move forward with my project …
(I am aware that I am a beginner here; similar to the PyMC community, I just hoped I might gain some high-level insights regarding PPLs here, beyond pure syntax issues.)

def draw_txt(letters, flatten=True):
    # Render `letters` onto a 20x20 white canvas with Pillow and return
    # the red channel, normalized to [0, 1], as a JAX array.
    fontsize = 10

    image = Image.new("RGBA", (20, 20), (255, 255, 255))
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(font=r"C:/Windows/Fonts/Arial.ttf", size=fontsize)
    draw.text((0, 0), letters, (0, 0, 0), font=font)

    img_data = jnp.array(image)[:, :, 0] / 255

    if flatten:
        img_data = img_data.flatten(order="C")

    return img_data
    

def model(obs_img, letters=None):

    # Discrete latent: which of the two candidate letters is present
    is_present = numpyro.sample("is_present", dist.Bernoulli(0.5), sample_shape=(1,))
    ind = jnp.where(is_present[0] == 1, 0, 1)

    img_mean = draw_txt(letters[ind])  # fails: ind is a JAX tracer, not a Python int
    img_cov = np.identity(img_mean.shape[0])

    likelihood = dist.MultivariateNormal(loc=img_mean, covariance_matrix=img_cov)
    numpyro.sample("obs", likelihood, obs=obs_img)

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)

letters = "SE"
obs_img = draw_txt("S")
mcmc.run(rng_key, obs_img=obs_img, letters=letters)

Because there are no continuous latent variables in your model, you can use infer_discrete for this. If you want to compile your model, there are two approaches:

  • generate a collection of image data outside of the model and just index into it inside the model; currently, the draw_txt function takes the input letters[ind], where ind might be a JAX tracer.
  • use JAX's host_callback or pure_callback to call draw_txt.

Thanks a lot for your help!
That helped me to get a running model.
At the moment, mcmc.get_samples() returns an empty dict.
I presume this has to do with your suggestion that I should use infer_discrete.

I am not sure where I have to implement that. In the annotation example, infer_discrete is used as an argument to get posterior predictive samples, if I understood correctly. But for this I would need posterior samples in the first place, I assume. There is also an infer_discrete decorator, but using it with my model fails with an AssertionError that I find hard to trace down.

Would you be so kind and have another look at my current model?

import numpy as np

import jax
import jax.numpy as jnp

import jax.experimental.host_callback as hcb
from jax import random

from PIL import Image, ImageDraw, ImageFont

import matplotlib.pyplot as plt

import numpyro
from numpyro import handlers
import numpyro.distributions as dist

from numpyro.infer import MCMC, NUTS, Predictive

from numpyro.contrib.funsor.discrete import infer_discrete
from numpyro.contrib.funsor.infer_util import config_enumerate

rng_key = random.PRNGKey(0)

from numpyro.ops.indexing import Vindex

def draw_txt(letter_ind, flatten=True):
    # Render the letter at position `letter_ind` with Pillow. This runs
    # outside of JAX (via host_callback), so `letter_ind` arrives as a
    # concrete value here.
    letters = "SE"
    txt = letters[letter_ind]

    fontsize = 10

    image = Image.new("RGBA", (20, 20), (255, 255, 255))
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(font=r"/Library/Fonts/Arial Unicode.ttf", size=fontsize)
    draw.text((0, 0), txt, (0, 0, 0), font=font)

    img_data = jnp.array(image)[:, :, 0] / 255

    if flatten:
        img_data = img_data.flatten(order="C")

    return img_data

#@infer_discrete(rng_key=rng_key)
#@config_enumerate
def model(obs_img):

    is_present = numpyro.sample("is_present", dist.Bernoulli(0.5), sample_shape=(1,))

    # host_callback transfers the traced index to the Python-side renderer
    img_mean = hcb.call(draw_txt, Vindex(is_present)[0],
                        result_shape=jax.ShapeDtypeStruct(obs_img.shape, obs_img.dtype))
    img_cov = np.identity(img_mean.shape[0])

    likelihood = dist.MultivariateNormal(loc=img_mean, covariance_matrix=img_cov)
    numpyro.sample("obs", likelihood, obs=obs_img)

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)

obs_img = draw_txt(0)
mcmc.run(rng_key, obs_img=obs_img)

mcmc.get_samples()