Sampling fails when using positive ordered constraints on a normally distributed variable that depends on a discrete variable

Hi,

I’m new to NumPyro and I’m struggling with getting the sampling to work for my task. I’m sampling a dataset of N elements, where each element is a pair of points p = (x0, x1) with the constraint that x1 > x0 and x0, x1 > 0. x0 and x1 are each sampled from a mixture model of two clusters. The cluster means/stdevs for x0 and x1 are different, but the cluster assignment is the same. I have written the following code in numpyro for this:

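import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import infer
from numpyro.distributions import transforms

def positive_ordered_model(N, pi, x0_prior, x1_prior):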
    with numpyro.plate("N", N):
        #i = cluster assignment
        i = numpyro.sample("i", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        print("i.shape: {}".format(i.shape))
        
        mean_x0 = x0_prior[i, 0]
        mean_x1 = x1_prior[i, 0]
        
        std_x0 = x0_prior[i, 1]
        std_x1 = x1_prior[i, 1]
        
        print("mean_x0.shape: {}, mean_x1.shape: {}".format(mean_x0.shape, mean_x1.shape))
        print("std_x0.shape: {}, std_x1.shape: {}".format(std_x0.shape, std_x1.shape))
        
        mean = jnp.array([mean_x0, mean_x1]).T
        std = jnp.array([std_x0, std_x1]).T
        
        p = numpyro.sample("p", dist.TransformedDistribution(
            dist.Normal(mean, std), 
            transforms.ComposeTransform([transforms.OrderedTransform(), transforms.ExpTransform()])
        ))

        print("p.shape: {}".format(p.shape))
        
        p0 = numpyro.deterministic("p0", p[:, 0])
        p1 = numpyro.deterministic("p1", p[:, 1])

The following is my code for sampling

sampler = infer.MCMC(
    infer.NUTS(positive_ordered_model),
    num_warmup=500,
    num_samples=500,
    num_chains=2,
    progress_bar=True
)

jrng_key = jax.random.PRNGKey(42)
N = 100
pi = jnp.array([0.5, 0.5])
x0_prior = jnp.array([[2, 0.5], [3, 0.5]])
x1_prior = jnp.array([[4, 0.5], [5, 0.5]])

sampler.run(
     jrng_key, 
     N, 
     pi,
     x0_prior,
     x1_prior
)

When I run this, I get the following error at the end of the stack trace

ValueError: Incompatible shapes for broadcasting: ((1, 100), (1, 2))

Before the error, the following print statements from the model definition are displayed

i.shape: (100,)
mean_x0.shape: (100,), mean_x1.shape: (100,)
std_x0.shape: (100,), std_x1.shape: (100,)
p.shape: (100, 2)
i.shape: (2, 1)
mean_x0.shape: (2, 1), mean_x1.shape: (2, 1)
std_x0.shape: (2, 1), std_x1.shape: (2, 1)

This suggests that the cluster assignment vector is being assigned the wrong shape. I’m not sure why.

I also tried moving the p = numpyro.sample(...) statement and the rest of the code out of the plate context, but that gives me the following error.

ValueError: Missing a plate statement for batch dimension -1 at site 'p'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.

Any help appreciated. Thanks in advance.

Hi @pankajb64, under enumeration the operations in your model need to accommodate extra batch dimensions. For example, x.T does not work as intended when x has more than one batch dimension, because it reverses all axes, and x.squeeze() will destroy all singleton batch dimensions. Under enumeration, the cluster assignment vector is expected to have shape (2, 1) (which broadcasts to (2, 100)), where 2 is the size of the support of i.
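
To make the shape issue concrete, here is a minimal sketch in plain NumPy (whose transpose rules jax.numpy follows) of what the `jnp.array([...]).T` construction from the model does in each of the two passes. The helper name `pair_with_T` is made up for illustration, and K is set to 3 rather than 2 only so that the enumeration axis is easy to tell apart from the pair axis.

```python
import numpy as np

def pair_with_T(a, b):
    # Same construction as in the model: stack on a new leading axis, then transpose.
    return np.array([a, b]).T

# First pass (no enumeration): per-element values have shape (N,).
N = 100
plain = pair_with_T(np.zeros(N), np.ones(N))
print(plain.shape)  # (100, 2): batch axis first, pair axis last, as intended

# Enumerated pass: values have shape (K, 1), one row per enumerated cluster.
K = 3  # illustrative; the thread uses 2 clusters
enum = pair_with_T(np.zeros((K, 1)), np.ones((K, 1)))
print(enum.shape)  # (1, 3, 2): .T reversed every axis, so batch dims (K, 1) became (1, K)
```

The pair axis ends up last in both cases, but in the enumerated pass the batch dimensions come out in reversed order, so they no longer line up with the plate.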

Thanks for the reply @fehiepsi. I’m not sure what you mean by "size of the support of i".

Instead of using .T, I am using jnp.dstack and that seems to work. Is this the right way to do it? If not, can you share what the right way to solve this problem is?

The model code now looks like below (after the dstack changes).

def positive_ordered_model(N, pi, x0_prior, x1_prior):
    with numpyro.plate("N", N):
        #i = cluster assignment
        i = numpyro.sample("i", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        
        mean_x0 = x0_prior[i, 0]
        mean_x1 = x1_prior[i, 0]
        
        std_x0 = x0_prior[i, 1]
        std_x1 = x1_prior[i, 1]
                
        mean = jnp.dstack([mean_x0, mean_x1])
        std = jnp.dstack([std_x0, std_x1])
        
       
        p = numpyro.sample("p", dist.TransformedDistribution(
            dist.Normal(mean, std), 
            transforms.ComposeTransform([transforms.OrderedTransform(), transforms.ExpTransform()])
        ))

        #take along last axis, deal with batching and variable number of dimensions
        p0 = numpyro.deterministic("p0", p.take(0, axis=-1))
        p1 = numpyro.deterministic("p1", p.take(1, axis=-1))

Because dstack stacks arrays in sequence depth wise (along third axis), it does not work for batched inputs. Why not just use jnp.stack([x0, x1], -1) if you want to stack along the last axis?
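
As a rough illustration of the difference, the sketch below compares the two calls in plain NumPy (jnp.dstack and jnp.stack follow the same shape rules) for a few batch shapes. The helper `shapes` is hypothetical, and the 3-D case is only an assumption about what a model with more batch dimensions might produce.

```python
import numpy as np

def shapes(batch_shape):
    # Return (dstack result shape, stack(axis=-1) result shape) for two inputs.
    a, b = np.zeros(batch_shape), np.ones(batch_shape)
    return np.dstack([a, b]).shape, np.stack([a, b], axis=-1).shape

print(shapes((100,)))     # ((1, 100, 2), (100, 2))
print(shapes((2, 1)))     # ((2, 1, 2), (2, 1, 2))
print(shapes((3, 2, 1)))  # ((3, 2, 2), (3, 2, 1, 2))
```

The two agree only for 2-D inputs. For 1-D inputs dstack adds an extra leading axis, and for inputs with three or more dimensions it concatenates along axis 2 instead of creating a new last axis.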

size of the support of i

The support of a distribution is the set of values it can take. Your categorical variable has support {0, 1}, so its size is 2. With something like Categorical(logits=jnp.zeros(10)), the size of the support would be 10, and enumeration would run over those 10 values.
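
A small sketch of what this means for shapes, written in plain NumPy and assuming the enumerated values are laid out on a new axis to the left of the plate axis (which matches the (2, 1) shape printed earlier in the thread). The names `i_enum` and `residual` are made up for illustration.

```python
import numpy as np

N = 100                                        # plate size
x0_prior = np.array([[2.0, 0.5], [3.0, 0.5]])  # one (mean, std) row per cluster
K = x0_prior.shape[0]                          # size of the support of i

# Under parallel enumeration, i stands for every support value at once.
i_enum = np.arange(K).reshape(K, 1)            # shape (2, 1), values [[0], [1]]

mean_x0 = x0_prior[i_enum, 0]                  # shape (2, 1), values [[2.], [3.]]
x0 = np.zeros(N)                               # placeholder for a per-element value, shape (100,)

# Broadcasting against the plate axis gives one row per cluster.
residual = x0 - mean_x0                        # shape (2, 100)
```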

Thanks @fehiepsi, I’ll use jnp.stack(..., -1).

I have another very related issue. I am running the following model, which is also failing due to array size mismatch, which is very likely a broadcasting error.

Here, I’m sampling a point x0, from a mixture model determined by the cluster assignment i, which I use to compute a deterministic variable p1 = intercept + x0*slope, where slope and intercept are also dependent on i. Additionally, I create another deterministic variable p0 = some_data_variable[i] which also depends on i in a simpler way. While p0 seems to respect batching, p1 does not.

def positive_ordered_model(N, pi, x0_prior, x1_prior, x_intercept, x_slope):
    with numpyro.plate("N", N):
        #i = cluster assignment
        i = numpyro.sample("i", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        print("i.shape: {}".format(i.shape))
        
        x0 = numpyro.sample("x0", dist.Normal(loc=x0_prior[i, 0], scale=x0_prior[i, 1]))
        
        p0 = numpyro.deterministic("p0", x1_prior[i])
        p1 = numpyro.deterministic("p1", x_intercept[i] + x0*x_slope[i])
        
        print("p0.shape: {}, p1.shape: {}".format(p0.shape, p1.shape))
        p = numpyro.deterministic("p", jnp.stack([p0, p1], axis=-1))

When I run it using the following code

N = 100
pi = jnp.array([0.5, 0.5])
x0_prior = jnp.array([[2, 0.5], [3, 0.5]])
x1_prior = jnp.array([0, 1])
x_intercept = jnp.array([10, 20])
x_slope = jnp.array([0.8, 0.4])

sampler = infer.MCMC(
    infer.NUTS(positive_ordered_model),
    num_warmup=500,
    num_samples=500,
    num_chains=2,
    progress_bar=True
)

jrng_key = jax.random.PRNGKey(42)

sampler.run(
     jrng_key, 
     N, 
     pi,
     x0_prior,
     x1_prior,
     x_intercept,
     x_slope
)

I get this error at the line p = numpyro.deterministic("p", jnp.stack([p0, p1], axis=-1))

ValueError: All input arrays must have the same shape.

And the following statements are printed before the error

i.shape: (100,)
p0.shape: (100,), p1.shape: (100,)
i.shape: (2, 1)
p0.shape: (2, 1), p1.shape: (2, 100)

Not sure why p1 has a broadcasted shape of (2, 100), I expected it to be (2, 1).

This is a limitation of jnp.stack (probably to match NumPy's behavior): it does not broadcast its inputs. I think you can broadcast p0 and p1 to the same shape before stacking, something like:

shape = jax.lax.broadcast_shapes(p0.shape, p1.shape)
p0, p1 = jnp.broadcast_to(p0, shape), jnp.broadcast_to(p1, shape)
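
For reference, a self-contained sketch of the same idea in plain NumPy, using np.broadcast_shapes as a stand-in for jax.lax.broadcast_shapes. It also shows where the (2, 100) shape presumably comes from: p0 depends only on the enumerated i, while p1 mixes i with x0, which keeps its plate shape (100,). The value used for x0 here is a placeholder, not a sample from the model.

```python
import numpy as np

N, K = 100, 2
i = np.arange(K).reshape(K, 1)         # enumerated cluster index, shape (2, 1)
x0 = np.linspace(2.0, 3.0, N)          # stand-in for the sampled x0, shape (100,)

x1_prior = np.array([0, 1])
x_intercept = np.array([10, 20])
x_slope = np.array([0.8, 0.4])

p0 = x1_prior[i]                       # depends only on i       -> (2, 1)
p1 = x_intercept[i] + x0 * x_slope[i]  # mixes i with x0 (100,)  -> (2, 100)

try:
    np.stack([p0, p1], axis=-1)
    stack_failed = False
except ValueError:
    stack_failed = True                # stack requires identical input shapes

# Broadcast both to the common shape first, then stack along a new last axis.
shape = np.broadcast_shapes(p0.shape, p1.shape)
p = np.stack([np.broadcast_to(p0, shape), np.broadcast_to(p1, shape)], axis=-1)
print(stack_failed, p.shape)           # True (2, 100, 2)
```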

Thanks, that seemed to resolve the error.

I now want to use the vectors p to sample other variables, which are in fact observed. This seems to cause a problem: the cluster assignment i is now returned as a single variable, while it is supposed to be a vector of size N (the length of the plate in which it is sampled).

The new model is below

def f(p, t):
    return 2*jnp.tanh(p[..., 0] - t) - jnp.tanh(p[..., 1] - 2*t)

def positive_ordered_model(N, pi, x0_prior, x1_prior, x_intercept, x_slope, M, N_to_M_mapping, t, values_obs, sigma):
    with numpyro.plate("N", N):
        #i = cluster assignment
        i = numpyro.sample("i", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        print("i.shape: {}".format(i.shape))
        
        x0 = numpyro.sample("x0", dist.Normal(loc=x0_prior[i, 0], scale=x0_prior[i, 1]))
        
        p0 = numpyro.deterministic("p0", x1_prior[i])
        p1 = numpyro.deterministic("p1", x_intercept[i] + x0*x_slope[i])
        
        shape = jax.lax.broadcast_shapes(p0.shape, p1.shape)
        p0, p1 = jnp.broadcast_to(p0, shape), jnp.broadcast_to(p1, shape)

        print("p0.shape: {}, p1.shape: {}".format(p0.shape, p1.shape))
        p = numpyro.deterministic("p", jnp.stack([p0, p1], axis=-1))
    
    with numpyro.plate("M", M):
        #points `p` are parameter values to a function which needs to be evaluated at certain given times
        #function values at these times are observed (with some noise)
        #for each point, the function may be observed at different times and a different number of times. len(t_point1) != len(t_point2)
        #Hence this plate is not nested inside the "N" plate.
        #vector `t` is a flattened vector of the times at which each point was observed, len(t) = M = len(t_point1) + len(t_point2) + ... + len(t_pointN)
        #vector `values_obs` is similarly a flattened vector of the corresponding noisy observed values.
        #the mapping of which value `V_m` is for which point `p_n` is given by `N_to_M_mapping`
        ##This could perhaps also be implemented as a scan over function which creates a plate of length len(t_pointn) for point n.
        p_input = p[..., N_to_M_mapping, :]
        values = numpyro.sample("values", dist.Normal(f(p_input, t), sigma), obs=values_obs)
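
As a side check on the indexing step only, the sketch below traces the shapes of p[..., N_to_M_mapping, :] and f in plain NumPy, with and without a leading enumeration axis. The sizes and the mapping array are made-up toy values. It only checks shapes and says nothing about whether enumeration is valid for this model.

```python
import numpy as np

def f(p, t):
    return 2 * np.tanh(p[..., 0] - t) - np.tanh(p[..., 1] - 2 * t)

N, K = 4, 2
# Hypothetical mapping: the point index that each of the M observations belongs to.
N_to_M_mapping = np.array([0, 0, 1, 2, 2, 2, 3])
M = len(N_to_M_mapping)
t = np.linspace(0.0, 1.0, M)

p_plain = np.zeros((N, 2))     # shape of p without enumeration
p_enum = np.zeros((K, N, 2))   # shape of p with the enumeration axis in front

for p in (p_plain, p_enum):
    p_input = p[..., N_to_M_mapping, :]   # gather one parameter pair per observation
    print(p.shape, "->", p_input.shape, "->", f(p_input, t).shape)
```

Indexing with the ellipsis leaves any leading axes untouched and replaces the N axis with an M axis, so f returns shape (M,) in the plain case and (K, M) in the enumerated one.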

The data and the code to run it are in this notebook: Cluster Assignment sampled as single variable instead of vector

When I run this, I get i as an array of shape (#chains, #draws) instead of the expected (#chains, #draws, N). What am I doing wrong?

It seems that your model does not satisfy this restriction. If so, you can disable enumeration and use other methods like DiscreteHMCGibbs to infer discrete values (though it might not work well for moderate to large N).

I would suggest that you create a new thread (because this question is unrelated to the original topic) and simplify the model to just a few lines of code that illustrate the main idea.