Extra sampling site in manual guide compared to model

Hi again.

What I’m actually trying to do is get a manual guide working for a discrete mixture model, as suggested by @fehiepsi here. Here is what I have. First up, the model:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def mix_weights(beta):
    """
    Stick-breaking construction: turn K-1 fractions in (0, 1)
    into K mixture weights that sum to one.
    """
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    term1 = jnp.pad(beta, (0, 1), mode='constant', constant_values=1.)
    term2 = jnp.pad(beta1m_cumprod, (1, 0), mode='constant', constant_values=1.)
    return jnp.multiply(term1, term2)


def model(
    K,
    N,
    D_discrete,
    X_mixture=None,
    alpha=None
):

    # priors
    if alpha is None:
        alpha = numpyro.sample("alpha", dist.Uniform(0.3, 10.0))

    # stick-breaking fractions for a Dirichlet process truncated at K components
    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample("v", dist.Beta(1, alpha))

    cluster_probabilities = numpyro.deterministic(
        "cluster_proba",
        mix_weights(v),
    )

    # one vector of D_discrete logits per cluster
    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent",
            dist.Normal(
                loc=jnp.zeros((K, D_discrete)),
                scale=jnp.ones((K, D_discrete))
            ).to_event(1)
        )

    phi = numpyro.deterministic("phi", jax.nn.sigmoid(_phi_latent))

    # model sampling
    with numpyro.plate('data', N):

        # Assignment is which cluster each row of data belongs to;
        # it is enumerated out in parallel rather than sampled.
        assignment = numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities),
            infer={'enumerate': 'parallel'}
        )

        obs = numpyro.sample(
            'obs',
            dist.Bernoulli(phi[assignment, :]).to_event(1),
            obs=X_mixture,
        )
```
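As a quick sanity check on mix_weights (a sketch, not part of the original setup): the stick-breaking construction should turn K-1 fractions into K non-negative weights that sum to one, matching the explicit formula w_k = v_k * prod_{j<k} (1 - v_j), with the last weight taking whatever stick is left. mix_weights_loop below is a hypothetical reference version written only for this comparison.

```python
import jax
import jax.numpy as jnp


def mix_weights(beta):
    """Stick-breaking weights, same as in the model above."""
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    term1 = jnp.pad(beta, (0, 1), mode="constant", constant_values=1.0)
    term2 = jnp.pad(beta1m_cumprod, (1, 0), mode="constant", constant_values=1.0)
    return term1 * term2


def mix_weights_loop(beta):
    """Hypothetical reference: break off a fraction of the remaining stick each step."""
    weights = []
    remaining = 1.0
    for b in beta:
        weights.append(b * remaining)      # piece broken off at this step
        remaining = remaining * (1 - b)    # stick left for later components
    weights.append(remaining)              # last component gets the leftover
    return jnp.stack(weights)


K = 5
beta = jax.random.uniform(jax.random.PRNGKey(0), (K - 1,))
w = mix_weights(beta)  # shape (K,), entries should sum to (approximately) one
```

The vectorised version avoids the Python loop, which matters once it sits inside a jitted model.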

And the guide:

```python
def guide(
    K,
    N,
    D_discrete,
    X_mixture=None,
    alpha=None,
):

    # one unconstrained coordinate each for alpha, the K-1 stick fractions
    # and the K * D_discrete phi logits
    n_vars = 1 + (K - 1) + D_discrete * K

    _latent_loc = numpyro.param("_latent_loc", jnp.zeros(n_vars))

    # initial scale_tril: 0.1 on the diagonal, 0.01 on every entry below it
    tmp = jnp.identity(n_vars) * 0.1 + jnp.tril(jnp.full((n_vars, n_vars), 0.01), k=-1)

    _latent_L = numpyro.param(
        "_latent_L",
        tmp,
        # lower_cholesky: lower triangular with a positive diagonal (a valid scale_tril);
        # corr_cholesky would also force unit-norm rows, i.e. unit marginal variances
        constraint=dist.constraints.lower_cholesky,
    )

    # joint multivariate Normal over all latents; this site exists only in the guide
    _latent_distribution = numpyro.sample(
        "_latent_distribution",
        dist.MultivariateNormal(_latent_loc, scale_tril=_latent_L),
        infer={'is_auxiliary': True}
    )

    if alpha is None:
        # squash onto (0.3, 10.0), the support of alpha's Uniform prior in the model
        numpyro.sample(
            "alpha",
            dist.Delta(0.3 + 9.7 * jax.nn.sigmoid(_latent_distribution[0]))
        )

    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample(
            "v",
            dist.Delta(
                jax.nn.sigmoid(_latent_distribution[1:K])
            ),
        )

    cluster_probabilities = numpyro.deterministic(
        "cluster_proba",
        mix_weights(v)
    )

    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent",
            dist.Delta(
                _latent_distribution[K:(K + D_discrete * K)].reshape(K, D_discrete)
            ).to_event(1)
        )
    phi = numpyro.deterministic(
        "phi",
        jax.nn.sigmoid(_phi_latent)
    )

    with numpyro.plate('data', N):
        assignment = numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities)
        )
```

The idea is to parameterise all the latent variables jointly as a single multivariate Normal, keeping track of every off-diagonal element of the covariance matrix between the model parameters so that their uncertainties are estimated better, and then to transform each variable onto its appropriate range (e.g. the phis are probabilities, so they go through a sigmoid function).
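To make the packing scheme concrete, here is a small sketch in plain JAX (no NumPyro primitives) with hypothetical sizes K = 3 and D = 2. It draws one reparameterised sample z = loc + L @ eps from the multivariate Normal, splits z in the same order as the guide, and pushes each piece through its transform. The affine-sigmoid map for alpha is an assumption, chosen so the value stays inside the support of the model's Uniform(0.3, 10.0) prior.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
K, D = 3, 2
n_vars = 1 + (K - 1) + D * K  # alpha, K-1 stick fractions, K*D phi logits

loc = jnp.zeros(n_vars)
# Lower-triangular scale factor: 0.1 on the diagonal, 0.01 below it.
L = 0.1 * jnp.eye(n_vars) + jnp.tril(jnp.full((n_vars, n_vars), 0.01), k=-1)

# One reparameterised draw from MVN(loc, L @ L.T).
eps = jax.random.normal(jax.random.PRNGKey(0), (n_vars,))
z = loc + L @ eps

# Unpack the flat vector in the guide's order and transform each block.
alpha = 0.3 + 9.7 * jax.nn.sigmoid(z[0])    # assumption: map onto (0.3, 10.0)
v = jax.nn.sigmoid(z[1:K])                  # stick fractions in (0, 1)
phi = jax.nn.sigmoid(z[K:].reshape(K, D))   # Bernoulli probabilities in (0, 1)

# Covariance implied by L: every pair of parameters gets its own term.
cov = L @ L.T
```

Because cov = L @ L.T, each entry below the diagonal of L contributes to the covariance between parameters, which is what lets a full-rank guide capture correlations that a mean-field guide cannot.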

I think I have set everything up correctly. It’s not completely clear to me which constraint is the right one for a lower Cholesky factor, but that is the only constraint needed, since the transformations take care of everything else.

Ideally, I would like this to work with TraceEnum_ELBO, but it fails with a KeyError:

KeyError: '_latent_distribution'

I think this is because _latent_distribution is sampled in the guide but not in the model. The documentation seems to suggest that the is_auxiliary option should make this work, but TraceEnum_ELBO appears to ignore it. Is this expected?

TraceGraph_ELBO does run with this setup, but I can’t get good parameter estimates from it, I think for the reasons Martin alludes to in this post.

Any suggestions?

EDIT: I’ve put this code in a Google Colab notebook, in case it’s useful.