Missing plate statement on batch dimension?

Hello,

I’m new to Pyro/NumPyro and trying to wrap my head around various details, shapes in particular.

I’m getting the following error from my model:

```
ValueError: Missing a plate statement for batch dimension -2 at site 'r'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
```

However, when I print the variable shapes and run `numpyro.util.format_shapes`, I don’t see the issue:

```
Z shape: (5,)
R shape: (5, 5)
T shape: (5, 5)
D shape: (5, 5)
   Trace Shapes:      
    Param Sites:      
           theta      
              pi      
           alpha      
   Sample Sites:      
beta_plate plate   4 |
       beta dist   4 |
           value   4 |
   z_plate plate   5 |
          z dist   5 |
           value   5 |
          r dist 5 5 |
           value 5 5 |
          t dist 5 5 |
           value 5 5 |
          d dist 5 5 |
           value 5 5 |
```

Here is my model:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints


def model(freq_matrix):
    '''A NumPyro model for the TIRM model.

    Arguments:
        freq_matrix: The frequency matrix of observations.

    Notes:
        Currently not sampling alpha.
    '''
    # Model parameters
    theta = numpyro.param("theta", jnp.array(0.5), constraint=constraints.interval(0.0, 1.0))
    pi = numpyro.param("pi", jnp.array(0.5), constraint=constraints.interval(0.0, 1.0))
    alpha = numpyro.param("alpha", jnp.array(10.0), constraint=constraints.positive)
    # Useful variables
    num_objs = freq_matrix.shape[0]
    num_preds = freq_matrix.shape[1]
    # Mixing weights for p(z)
    with numpyro.plate("beta_plate", num_objs - 1):
        beta = numpyro.sample("beta", dist.Beta(1, alpha))
    # p(z)
    with numpyro.plate("z_plate", num_objs):
        mix = mix_weights(beta)  # helper defined elsewhere (not shown in this post)
        z = numpyro.sample("z", dist.Categorical(mix), infer={"enumerate": "parallel"})
        print(f"Z shape: {z.shape}")
        # p(r | z)
        r = numpyro.sample("r", dist.Bernoulli(jnp.ones((num_objs, num_preds)) * theta))
        print(f"R shape: {r.shape}")
        # p(t | r, z)
        t = numpyro.sample("t", dist.Bernoulli(r[z] * pi))
        print(f"T shape: {t.shape}")
        # p(d | r, t, z)
        d = numpyro.sample("d", dist.Delta(t + r[z] - 1), obs=freq_matrix)
        print(f"D shape: {d.shape}")
```
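The post does not show `mix_weights`. Since it turns `num_objs - 1` Beta draws into `num_objs` mixing weights, it is presumably the stick-breaking construction, like the helper of the same name in Pyro's Dirichlet process mixture tutorial. The following is a hypothetical reconstruction under that assumption, written for a 1-D `beta` only.

```python
import jax.numpy as jnp


def mix_weights(beta):
    # Hypothetical reconstruction: stick-breaking weights from K-1 Beta draws (1-D beta).
    # Cumulative product of the leftover stick lengths (1 - beta_j).
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    # w_k = beta_k * prod_{j<k} (1 - beta_j); the last weight takes whatever stick remains.
    return jnp.pad(beta, (0, 1), constant_values=1) * jnp.pad(
        beta1m_cumprod, (1, 0), constant_values=1
    )


w = mix_weights(jnp.array([0.5, 0.5, 0.5, 0.5]))
print(w)
```

With four Beta draws this yields five weights that sum to 1, so `dist.Categorical(mix)` has an empty batch shape and `z_plate` expands it to `(num_objs,)`, consistent with the `z` row in the trace table.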

I’m sure this is something silly I’m doing, but I can’t seem to figure it out yet.

Any help is appreciated!

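You have to state which dimensions are conditionally independent. Inside `z_plate` (size `num_objs`, dim -1), `r` has batch shape `(num_objs, num_preds)`. The plate is aligned with the rightmost batch dimension, which is the `num_preds` axis (this only passes because `num_preds == num_objs == 5` here), so the `num_objs` axis at dimension -2 has no plate. The same applies to `t` and `d`. You should use `to_event`, e.g. `dist.Bernoulli(...).to_event(1)`, so that the `num_preds` axis becomes part of the event shape and the remaining batch dimension is the one declared by `z_plate`. Check this: https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/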