Getting inference to work with a nonparametric hierarchical model?

Hello,

I’m new to Pyro, so I’m hoping this is a case of me being “silly” and missing something. I’m trying to build a model that generates boolean matrices based on latent categories inferred from the data.

Formally, the model is closely related to Stochastic Block Models, Cross-Cutting Models, and the Infinite Relational Model (if there’s a hidden tutorial on building these last two models, do let me know!):

P(D, R, T, Z) \propto P(D | R, T, Z)P(T|R, Z)P(R|Z)P(Z)

D, R, and T are all I × J matrices. D is the observed matrix, whose elements are -1, 0, or 1. The elements of T and R are either 0 or 1. Z is a length-I vector holding the inferred cluster index of each object i; since there can be at most I clusters, indexed from 0, its entries range from 0 to I − 1.
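To make the range of Z concrete, here is a minimal plain-PyTorch sketch of a CRP draw, independent of Pyro (`sample_crp` is a hypothetical helper name). A new table always receives the next unused label, so labels are contiguous, and the i-th object (counting from 0) can receive at most label i.

```python
import torch

def sample_crp(num_objs: int, alpha: float, generator=None) -> torch.Tensor:
    """Draw cluster labels z_0..z_{I-1} from a Chinese Restaurant Process."""
    counts = []  # counts[k] = number of objects already seated at table k
    z = []
    for _ in range(num_objs):
        # Existing tables are weighted by occupancy, a new table by alpha.
        weights = torch.tensor(counts + [alpha], dtype=torch.float)
        k = int(torch.multinomial(weights / weights.sum(), 1, generator=generator))
        if k == len(counts):
            counts.append(1)  # open a new table with the next unused label
        else:
            counts[k] += 1
        z.append(k)
    return torch.tensor(z)

g = torch.Generator().manual_seed(0)
print(sample_crp(num_objs=8, alpha=1.0, generator=g))
```

Calling `sample_crp` repeatedly shows the number of distinct labels varying from draw to draw, which is why the number of `r_{i}_{j}` sites in the model below changes between traces.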

I can build the forward model and it works, but everything breaks down when I get to inference. With SVI the loss is inf and NaNs appear on every run, leading to RuntimeErrors. With MCMC and NUTS the shapes of the sample sites change, and I’m not sure how to deal with that in a model whose support is stochastic (Z is inferred from the data via a Chinese Restaurant Process).
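The shape problem comes from the CRP creating one row of R per occupied table: the set of latent sites differs between traces, while HMC-based samplers such as NUTS expect the same sites with the same shapes every time. A standard workaround is to replace the CRP with a truncated stick-breaking approximation of the underlying Dirichlet process, so Z always has shape (I,) and R always has shape (K, J) for a chosen truncation level K. The sketch below is plain PyTorch and only illustrates the fixed-shape generative process; `sample_truncated` and `K` are hypothetical, and in a Pyro model each draw would become a `pyro.sample` site.

```python
import torch
from torch.distributions import Bernoulli, Beta, Categorical

def sample_truncated(num_objs, num_preds, alpha, theta, pi, K):
    """Fixed-shape approximation of the CRP prior via truncated stick-breaking.

    K (hypothetical truncation level, K >= 2) caps the number of clusters, so
    every latent has the same shape on every draw, whatever z turns out to be.
    """
    # Stick-breaking: v_k ~ Beta(1, alpha), w_k = v_k * prod_{l<k} (1 - v_l)
    v = Beta(torch.ones(K - 1), torch.full((K - 1,), alpha)).sample()
    v = torch.cat([v, torch.ones(1)])  # the last stick takes the remaining mass
    remaining = torch.cat([torch.ones(1), torch.cumprod(1 - v[:-1], dim=0)])
    w = v * remaining                  # mixture weights, shape (K,), sum to 1
    z = Categorical(probs=w).sample((num_objs,))                     # (I,)
    r = Bernoulli(probs=torch.full((K, num_preds), theta)).sample()  # (K, J)
    t = Bernoulli(probs=r[z] * pi).sample()                          # (I, J)
    return w, z, r, t

torch.manual_seed(0)
w, z, r, t = sample_truncated(num_objs=6, num_preds=4, alpha=10.0, theta=0.5, pi=0.5, K=5)
print(z.shape, r.shape, t.shape)  # torch.Size([6]) torch.Size([5, 4]) torch.Size([6, 4])
```

Clusters with index above the number actually used simply keep unused rows of R, which is the price of fixed shapes; K should be large enough that the last weights are negligible.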

How can I get inference to work on this model?

I’m happy to share any information needed! I’m not sure what would be helpful at the moment, so do let me know.

Below is the model code:

import torch
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(freq_matrix):
  '''A Pyro model for the TIRM model.

  Arguments:
    freq_matrix: The observed I x J matrix D, with entries in {-1, 0, 1}.

  Notes:
    Currently not sampling alpha.
  '''
  # Model parameters (alpha must be a float tensor to be a learnable param)
  theta = pyro.param("theta", torch.tensor(0.5), constraint=constraints.interval(0., 1.))
  pi = pyro.param("pi", torch.tensor(0.5), constraint=constraints.interval(0., 1.))
  alpha = pyro.param("alpha", torch.tensor(10.), constraint=constraints.positive)
  # Useful variables
  num_objs, num_preds = freq_matrix.shape
  obj_axis = pyro.plate("obj_axis", num_objs)
  pred_axis = pyro.plate("pred_axis", num_preds)
  # p(Z), modeled by a CRP
  freqs = []  # number of customers at each table
  z = []
  for i in range(num_objs):
    # Existing tables weighted by occupancy, a new table by alpha; torch.cat keeps
    # alpha in the autograd graph (torch.tensor(freqs + [alpha]) would copy it out)
    weights = torch.cat([torch.tensor(freqs), alpha.reshape(1)])
    probs = weights / weights.sum()
    z_item = pyro.sample(f"z_{i}", dist.Categorical(probs)).item()
    z.append(z_item)
    if z_item >= len(freqs):
      freqs.append(1.)  # open a new table
    else:
      freqs[z_item] += 1.
  z = torch.tensor(z)
  assert z.shape == (num_objs,)
  num_types = len(torch.unique(z))
  # p(R | Z): one row of Bernoulli(theta) features per occupied table
  r = torch.zeros(num_types, num_preds)
  for k in range(num_types):
    for j in range(num_preds):
      r[k, j] = pyro.sample(f"r_{k}_{j}", dist.Bernoulli(theta))
  # Pad R to I rows, since its number of rows changes depending on z
  r = F.pad(input=r, pad=(0, 0, 0, num_objs - r.shape[0]), mode='constant', value=0.)
  assert r.shape == (num_objs, num_preds)
  t = torch.zeros(num_objs, num_preds)
  obs = torch.zeros(num_objs, num_preds)
  for i in range(num_objs):
    for j in range(num_preds):
      # p(T | R, Z)
      t[i, j] = pyro.sample(f"t_{i}_{j}", dist.Bernoulli(r[z[i], j] * pi))
      # p(D | T, R, Z): deterministic, so any mismatch with the data has log-prob -inf
      if r[z[i], j] == 0:
        obs[i, j] = pyro.sample(f"d_{i}_{j}", dist.Delta(torch.tensor(-1.)), obs=freq_matrix[i, j])
      else:
        obs[i, j] = pyro.sample(f"d_{i}_{j}", dist.Delta(t[i, j]), obs=freq_matrix[i, j])
  assert t.shape == (num_objs, num_preds)
  assert obs.shape == (num_objs, num_preds)
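The observation step in the loops above is deterministic: given Z, R and T, each d_ij is fixed (-1 where R[z_i, j] = 0, otherwise t_ij), and Pyro's `Delta` assigns log-probability log 0 = -inf to any other value. Summed over all I × J entries, a single mismatch between a sampled T (or R, Z) and the data makes the whole log-likelihood -inf, which is one plausible source of the inf loss under SVI. The plain-PyTorch sketch below reproduces that computation in vectorized form and contrasts it with a softened noise model; `predicted_d`, `delta_log_lik`, `noisy_log_lik` and the noise level `eps` are hypothetical and not part of the original model.

```python
import torch

def predicted_d(z, r, t):
    """Deterministic D implied by the model: -1 where R[z_i, j] == 0, else T[i, j]."""
    return torch.where(r[z] == 0, torch.full_like(t, -1.0), t)

def delta_log_lik(d_obs, z, r, t):
    """Log-likelihood under Delta observations: 0 if every entry matches, else -inf."""
    match = predicted_d(z, r, t) == d_obs
    return torch.where(match, torch.zeros_like(d_obs),
                       torch.full_like(d_obs, float("-inf"))).sum()

def noisy_log_lik(d_obs, z, r, t, eps=0.05):
    """Hypothetical softened likelihood: D equals the prediction with prob 1 - eps,
    otherwise one of the other two values in {-1, 0, 1}, each with prob eps / 2."""
    match = predicted_d(z, r, t) == d_obs
    logp = torch.where(match, torch.log(torch.tensor(1 - eps)),
                       torch.log(torch.tensor(eps / 2)))
    return logp.sum()

# Toy example: 3 objects, 2 predicates, 2 clusters, one mismatched entry at (2, 1).
z = torch.tensor([0, 0, 1])
r = torch.tensor([[1., 0.], [1., 1.]])
t = torch.tensor([[1., 0.], [0., 0.], [1., 1.]])
d_obs = torch.tensor([[1., -1.], [0., -1.], [1., 0.]])
print(delta_log_lik(d_obs, z, r, t))  # tensor(-inf)
print(noisy_log_lik(d_obs, z, r, t))  # finite
```

With the softened version every configuration has nonzero probability, so gradient-based inference gets a usable signal; whether such a noise term fits the intended model is a modelling decision.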