Model with a joint posterior distribution

Hi there,

Thanks for all the help so far @fehiepsi - you’ve been awesome.

I am trying to construct a joint model based on what I’ve done so far. The posterior is the product of a mixture model and a multilevel logistic regression model: the regression has one intercept per cluster, and the cluster assignments are determined by the mixture model.

Here is what I have so far:

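import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate

@config_enumerate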
def model(
    K, 
    N, 
    D_discrete, 
    D_response,
    X_response=None,
    X_discrete=None,
    y=None
):
    
    if y is not None:
        y = jnp.array(y)
    
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K))) 
    with numpyro.plate('discrete_components', D_discrete):
        with numpyro.plate("discrete_cluster", K):
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0)) 

    with numpyro.plate("response_components", D_response):
        beta = numpyro.sample('beta', dist.Normal(0.0, 1.0))
    
    with numpyro.plate("dummy", 1): # note - without this I get errors about broadcasting
        with numpyro.plate("response_cluster", K):
            intercepts = numpyro.sample('intercepts', dist.Normal(0.0, 1.0))
    # note - if I try and remove the plates around this I get an error about a missing plate.

    with numpyro.plate('data', N):
        
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba)) 
    
        obs1 = numpyro.sample(
            'obs1', 
            dist.Bernoulli(phi[assignment, :]).to_event(1),
            obs=X_discrete,
        )
        
        linear_predictor = jnp.sum(intercepts[assignment, :] + (beta * X_response), axis=-1)
        obs2 = numpyro.sample(
            "obs2",
            dist.Bernoulli(logits=linear_predictor),
            obs=y
        )
        

As you can see, the model samples two observed sites. The code runs with both MCMC and VI, and on simulated input data the fitted betas and phis are approximately correct. The problem is that the intercepts are not. I assume I’ve misspecified the model somehow, but I can’t figure out what the issue is. Can anyone suggest anything?

Thanks!

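Apologies, I’ve figured this out. The model was indeed improperly specified. I’d put the jnp.sum function around the intercepts by mistake, so each observation’s intercept was broadcast across the response covariates and added D_response times instead of once. The linear predictor should be defined as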

linear_predictor = intercepts[assignment] + jnp.sum(beta * X_response, axis=-1)

For that reason I also don’t need the dummy plate: the intercepts can simply have shape (K,).
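
To make the difference concrete, here is a minimal sketch comparing the two expressions. It uses plain NumPy rather than the model itself (jnp broadcasts the same way), and the sizes and random values are made up for illustration, not taken from my data. It assumes the intercepts have shape (K, 1), which is what the extra size-1 dummy plate produces.

```python
import numpy as np

rng = np.random.default_rng(0)
K, N, D = 3, 5, 4  # clusters, observations, response covariates (illustrative sizes)

intercepts = rng.normal(size=(K, 1))     # shape given by the size-1 "dummy" plate
beta = rng.normal(size=D)
X_response = rng.normal(size=(N, D))
assignment = rng.integers(0, K, size=N)  # stand-in for the sampled cluster labels

# Original expression: the (N, 1) intercept broadcasts to (N, D) before the sum,
# so every row adds its intercept D times.
buggy = np.sum(intercepts[assignment, :] + beta * X_response, axis=-1)

# Corrected expression, with intercepts of shape (K,) once the dummy plate is removed.
intercepts_1d = intercepts[:, 0]
fixed = intercepts_1d[assignment] + np.sum(beta * X_response, axis=-1)

# The two differ by exactly (D - 1) extra copies of the intercept.
extra = buggy - fixed
print(np.allclose(extra, (D - 1) * intercepts_1d[assignment]))
```

So in the original model the sampled intercepts entered the linear predictor multiplied by D_response, which would explain why the fitted values did not match the simulated ones while the betas did.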