Hello,
I’m new to Pyro/NumPyro and still wrapping my head around various details, shapes among them.
I’m getting the following error for my model:
ValueError: Missing a plate statement for batch dimension -2 at site 'r'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
But when I print the variable shapes and run util.format_shapes on the model trace, I don’t see the problem:
Z shape: (5,)
R shape: (5, 5)
T shape: (5, 5)
D shape: (5, 5)
Trace Shapes:
 Param Sites:
        theta
           pi
        alpha
Sample Sites:
   beta_plate plate   4 |
         beta  dist   4 |
              value   4 |
      z_plate plate   5 |
            z  dist   5 |
              value   5 |
            r  dist 5 5 |
              value 5 5 |
            t  dist 5 5 |
              value 5 5 |
            d  dist 5 5 |
              value 5 5 |
Here is my model:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints

# mix_weights(beta) is a helper defined elsewhere (not shown in this post).

def model(freq_matrix):
    '''A Pyro model for the TIRM model.

    Arguments:
        freq_matrix: The frequency matrix of observations.

    Notes:
        Currently not sampling alpha.
    '''
    # Model parameters
    theta = numpyro.param("theta", jnp.array(0.5), constraint=constraints.interval(0.0, 1.0))
    pi = numpyro.param("pi", jnp.array(0.5), constraint=constraints.interval(0.0, 1.0))
    alpha = numpyro.param("alpha", jnp.array(10.0), constraint=constraints.positive)

    # Useful variables
    num_objs = freq_matrix.shape[0]
    num_preds = freq_matrix.shape[1]

    # Mixing weights for p(z)
    with numpyro.plate("beta_plate", num_objs - 1):
        beta = numpyro.sample("beta", dist.Beta(1, alpha))

    # p(z)
    with numpyro.plate("z_plate", num_objs):
        mix = mix_weights(beta)
        z = numpyro.sample("z", dist.Categorical(mix), infer={"enumerate": "parallel"})
        print(f"Z shape: {z.shape}")

        # p(r | z)
        r = numpyro.sample("r", dist.Bernoulli(jnp.ones((num_objs, num_preds)) * theta))
        print(f"R shape: {r.shape}")

        # p(t | r, z)
        t = numpyro.sample("t", dist.Bernoulli(r[z] * pi))
        print(f"T shape: {t.shape}")

        # p(d | r, t, z)
        d = numpyro.sample("d", dist.Delta(t + r[z] - 1), obs=freq_matrix)
        print(f"D shape: {d.shape}")
I’m sure this is something silly I’m doing, but I can’t seem to figure it out yet.
Any help is appreciated!