Hi,
I’m new to NumPyro and I’m struggling to get sampling to work for my task. I’m sampling a dataset of N elements, where each element is a pair of points p = (x0, x1) with the constraints x1 > x0 and x0, x1 > 0. x0 and x1 are each sampled from a mixture model with two clusters. The cluster means and standard deviations differ between x0 and x1, but both share the same cluster assignment. I have written the following NumPyro code for this:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import infer
from numpyro.distributions import transforms

def positive_ordered_model(N, pi, x0_prior, x1_prior):
    with numpyro.plate("N", N):
        # i = cluster assignment, shared by x0 and x1
        i = numpyro.sample("i", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        print("i.shape: {}".format(i.shape))
        # Column 0 of each prior holds the cluster means, column 1 the standard deviations
        mean_x0 = x0_prior[i, 0]
        mean_x1 = x1_prior[i, 0]
        std_x0 = x0_prior[i, 1]
        std_x1 = x1_prior[i, 1]
        print("mean_x0.shape: {}, mean_x1.shape: {}".format(mean_x0.shape, mean_x1.shape))
        print("std_x0.shape: {}, std_x1.shape: {}".format(std_x0.shape, std_x1.shape))
        mean = jnp.array([mean_x0, mean_x1]).T
        std = jnp.array([std_x0, std_x1]).T
        # Ordered then exp transform, intended to give 0 < p[..., 0] < p[..., 1]
        p = numpyro.sample("p", dist.TransformedDistribution(
            dist.Normal(mean, std),
            transforms.ComposeTransform([transforms.OrderedTransform(), transforms.ExpTransform()])
        ))
        print("p.shape: {}".format(p.shape))
        p0 = numpyro.deterministic("p0", p[:, 0])
        p1 = numpyro.deterministic("p1", p[:, 1])
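For reference, the generative process this model is meant to encode can be sketched in plain NumPy, outside of NumPyro. This is only an illustration: the helper name `sample_pairs` is hypothetical, and the sketch assumes the ordered step maps (a, b) to (a, a + exp(b)) before the elementwise exp is applied.

```python
import numpy as np

def sample_pairs(rng, N, pi, x0_prior, x1_prior):
    """Forward-sample N pairs (p0, p1) with 0 < p0 < p1 (illustrative only)."""
    # One cluster assignment per element, shared by x0 and x1.
    i = rng.choice(len(pi), size=N, p=pi)
    # Column 0 of each prior holds the mean, column 1 the standard deviation.
    mean = np.stack([x0_prior[i, 0], x1_prior[i, 0]], axis=-1)  # (N, 2)
    std = np.stack([x0_prior[i, 1], x1_prior[i, 1]], axis=-1)   # (N, 2)
    z = rng.normal(mean, std)                                   # (N, 2)
    # Ordered step (assumed form): (a, b) -> (a, a + exp(b)).
    ordered = np.stack([z[:, 0], z[:, 0] + np.exp(z[:, 1])], axis=-1)
    # Exp step: makes both entries positive and preserves their order.
    return i, np.exp(ordered)

rng = np.random.default_rng(0)
pi = np.array([0.5, 0.5])
x0_prior = np.array([[2.0, 0.5], [3.0, 0.5]])
x1_prior = np.array([[4.0, 0.5], [5.0, 0.5]])
i, p = sample_pairs(rng, 100, pi, x0_prior, x1_prior)
print(p.shape)
```

Here the pair is built with `np.stack(..., axis=-1)` so that the pair dimension is always the last axis, whatever the shape of `i`.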
The following is my code for sampling:
sampler = infer.MCMC(
    infer.NUTS(positive_ordered_model),
    num_warmup=500,
    num_samples=500,
    num_chains=2,
    progress_bar=True
)
jrng_key = jax.random.PRNGKey(42)
N = 100
pi = jnp.array([0.5, 0.5])
# Each row is one cluster: [mean, std]
x0_prior = jnp.array([[2, 0.5], [3, 0.5]])
x1_prior = jnp.array([[4, 0.5], [5, 0.5]])
sampler.run(
    jrng_key,
    N,
    pi,
    x0_prior,
    x1_prior
)
When I run this, I get the following error at the end of the stack trace:
ValueError: Incompatible shapes for broadcasting: ((1, 100), (1, 2))
Before the error, the print statements in the model definition produce the following output:
i.shape: (100,)
mean_x0.shape: (100,), mean_x1.shape: (100,)
std_x0.shape: (100,), std_x1.shape: (100,)
p.shape: (100, 2)
i.shape: (2, 1)
mean_x0.shape: (2, 1), mean_x1.shape: (2, 1)
std_x0.shape: (2, 1), std_x1.shape: (2, 1)
This suggests that the cluster assignment vector is being assigned the wrong shape. I’m not sure why.
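A small NumPy check (a sketch, not NumPyro itself) reproduces how the two printed shapes of i flow through the `jnp.array([...]).T` construction. It assumes, without confirmation from the traceback, that the (2, 1) shape comes from `infer={'enumerate': 'parallel'}` placing the two possible cluster values on a new leading axis. The helper name `mean_shapes` is hypothetical.

```python
import numpy as np

x0_prior = np.array([[2.0, 0.5], [3.0, 0.5]])
x1_prior = np.array([[4.0, 0.5], [5.0, 0.5]])

def mean_shapes(i):
    """Return the shapes produced by two ways of pairing the per-cluster means."""
    mean_x0 = x0_prior[i, 0]
    mean_x1 = x1_prior[i, 0]
    via_transpose = np.array([mean_x0, mean_x1]).T     # .T reverses every axis
    via_stack = np.stack([mean_x0, mean_x1], axis=-1)  # appends a new last axis
    return via_transpose.shape, via_stack.shape

# First set of printed shapes: one assignment per element, shape (100,).
i_plain = np.zeros(100, dtype=int)
print(mean_shapes(i_plain))  # ((100, 2), (100, 2))

# Second set of printed shapes: i has shape (2, 1).
i_enum = np.arange(2).reshape(2, 1)
print(mean_shapes(i_enum))   # ((1, 2, 2), (2, 1, 2))
```

With i of shape (2, 1), the transpose yields (1, 2, 2). Once the last axis is treated as the pair, a dimension of size 2 is left where the plate expects size 100, which would be consistent with the `((1, 100), (1, 2))` in the error message. Stacking along the last axis keeps the leading dimensions in place instead. Whether swapping this into the model resolves the error has not been tested here.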
I also tried moving the p = numpyro.sample(...) statement and the rest of the code out of the plate context, but that gives me the following error:
ValueError: Missing a plate statement for batch dimension -1 at site 'p'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
Any help is appreciated. Thanks in advance.