Handle 2d plate

Hi all,
I’m a new Pyro user and I’m trying to handle a simple (in theory) problem:
I have a 500x300 tensor in which each row holds 300 independent draws from a Poisson distribution whose rate lambda is itself drawn from a gamma distribution (Gamma-Poisson model).

import torch
import pyro
import math
import matplotlib.pyplot as plt
import numpy as np
import pyro.optim
from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS
import pyro.distributions as dist
from torch.distributions import constraints

back_dist = np.random.gamma(5,1.2,500)

dt = np.zeros((500,300))
for i in range(500):
  dt[i] = np.random.poisson(back_dist[i], 300)

X = torch.tensor(dt, dtype=torch.float32)
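
A side note on parameterization that may matter when comparing fitted values against the simulation settings: np.random.gamma takes (shape, scale), whereas torch.distributions.Gamma (and pyro.distributions.Gamma) takes (concentration, rate), with rate = 1/scale. The sketch below only checks that both give the same mean, shape * scale = 6.0. The seeds and sample size are arbitrary choices for illustration.

```python
import numpy as np
import torch

shape, scale = 5.0, 1.2          # the values passed to np.random.gamma above
n_draws = 200_000                # arbitrary, just large enough for a stable mean

# NumPy: second argument is the scale
rng = np.random.default_rng(0)
np_draws = rng.gamma(shape, scale, size=n_draws)

# torch / Pyro: second argument is the rate, i.e. 1 / scale
torch.manual_seed(0)
torch_gamma = torch.distributions.Gamma(concentration=shape, rate=1.0 / scale)
torch_draws = torch_gamma.sample((n_draws,))

expected_mean = shape * scale
np_mean = float(np_draws.mean())
torch_mean = float(torch_draws.mean())
print(expected_mean, np_mean, torch_mean)
```

Under this reading, the simulated data correspond to roughly alpha = 5 and beta = 1/1.2 in Pyro's convention.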

I would like to estimate the parameters of the gamma distribution (“latent rate”, common to all rows) and the lambda specific to each row. Here is what I’m trying to do. I’m sure I’m making mistakes with the plates:

def gamma_poisson_model(data):
    #N = 250
    N, D = data.shape
    # define the hyperparameters that control the Gamma prior
    alpha0 = torch.tensor(2.0)
    beta0 = torch.tensor(2.0)
    # sample f from the Gamma prior
    f = pyro.sample("latent_rate", dist.Gamma(alpha0,beta0))
    # loop over the observed data
    with pyro.plate('x_data', N):
      with pyro.plate('y_data', D):
        pyro.sample('obs', dist.Poisson(f), obs=data)    

def guide(data):
    # register the two variational parameters with Pyro.
    alpha_q = pyro.param("alpha_q", torch.tensor(2.0), constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(2.0), constraint=constraints.positive)
    # sample latent rate from the distribution Gamma(alpha_q, beta_q)
    pyro.sample("latent_rate", dist.Gamma(alpha_q, beta_q))

Do you have any advice to solve the problem? I read that I should use “to_event” to declare each independent row, but it’s not clear how to use it.

Thank you for your help

plates go from right to left in terms of tensor dimensions and out-to-in in terms of python context managers. so i think you want

with pyro.plate('y_data', D):
    with pyro.plate('x_data', N):
        pyro.sample('obs', dist.Poisson(f), obs=data)

Hi, I think you can also assign the tensor dimensions to the plates explicitly by writing

with pyro.plate('x_data', N, dim=-2):
    with pyro.plate('y_data', D, dim=-1):
        pyro.sample('obs', dist.Poisson(f), obs=data)

Or, you could transpose the data matrix (X = X.T) and reverse the dimensions (D, N = data.shape).

However, I think your model has other problems: At the moment, you are not modeling the individual lambdas, but only a single rate parameter that is common for each element of the data matrix. Is the goal also to estimate the parameters of the gamma distribution from which the lambdas are sampled? I tried “fixing” it, and came up with the following:

def gamma_poisson_model(data):
    D, N = data.shape  # expects the transposed data, shape (D, N)
    # register the hyperparameters that control the Gamma prior
    alpha = pyro.param("alpha", torch.tensor(2.0), constraint=constraints.positive)
    beta = pyro.param("beta", torch.tensor(2.0), constraint=constraints.positive)
    # loop over the observed data
    with pyro.plate('x_data', N):
        # sample the rate parameter for D i.i.d. observations
        lam = pyro.sample("lambda", dist.Gamma(alpha, beta))
        with pyro.plate('y_data', D):
            pyro.sample('obs', dist.Poisson(lam), obs=data)    

def guide(data):
    D, N = data.shape
    # register variational parameters for lambda
    lam_q = pyro.param("lambda_q", torch.full((N,), 2.0), constraint=constraints.positive)
    with pyro.plate('x_data', N):
        # sample rate parameter lambda
        lam = pyro.sample("lambda", dist.Delta(lam_q))

This just produces MAP estimates of the parameters, and alpha and beta don’t even have a prior. But the point is that the pyro.sample call for lambda sits inside the first plate context. Note that if you instead fix the order of the dims with dim=-2 and dim=-1, you have to make sure that lam_q has shape (N,1) instead of (N,), because otherwise lam will be broadcast to an (N,N) matrix instead of a vector.

Disclaimer: I am a pyro newbie myself, so I hope the pyro experts will correct me if I made a mistake here.

Thank you!
@martinjankowiak my solution was transposing the data → data.T

@chvandorp Thank you for the advice. Yes, I was working on something similar to your solution; thank you so much for your time. I was trying to compare Pyro with some earlier Stan code that solves the same task.