Extra sampling site in manual guide compared to model

Hi,

I’m trying to write a manual guide for a model. I have a 2D array of parameters that I’ve defined like this:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(D, E):
    with numpyro.plate("x1", D):
        with numpyro.plate("x2", E):
            numpyro.sample("phi", dist.Beta(5.0, 5.0))

I’ve written a guide that uses independent Gaussians to parameterise the phi site in the model:

def guide_indep(D, E):

    with numpyro.plate("x1", D):
        with numpyro.plate("x2", E):

            loc = numpyro.param("loc", 0.0)
            scale = numpyro.param(
                "scale",
                0.1,
                constraint=dist.constraints.positive
            )

            numpyro.sample(
                "phi",
                dist.TransformedDistribution(
                    dist.Normal(loc, scale),
                    dist.transforms.SigmoidTransform()
                )
            )

This works.
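As a quick shape check (e.g. with D=2, E=3), tracing the guide shows that the phi site picks up the (D, E) batch shape from the two plates:

tr = numpyro.handlers.trace(
    numpyro.handlers.seed(guide_indep, jax.random.PRNGKey(0))
).get_trace(2, 3)
print(tr["phi"]["value"].shape)  # (2, 3)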

I would like to write a guide that parameterises the phis as a multivariate Gaussian, allowing for correlations between the phi parameters (i.e. a covariance matrix with non-zero off-diagonal elements). I can’t figure out how to get this to work: dist.MultivariateNormal returns samples containing more than one element, and that messes up the array sizes inside the plates. I would like to do something like this:

def guide_MVN(D, E):

    loc = numpyro.param("loc", jnp.zeros(D * E))
    cov_mat = numpyro.param(
        "cov_mat",
        jnp.diag(jnp.ones(D * E) * 0.09) + jnp.ones((D * E, D * E)) * 0.01,
        constraint=dist.constraints.positive_definite
    )
    phi_latent = numpyro.sample(
        "phi_latent",
        dist.TransformedDistribution(
            dist.MultivariateNormal(
                loc,
                covariance_matrix=cov_mat
            ),
            dist.transforms.SigmoidTransform()
        )  
    )

    with numpyro.plate("x1", D):
        with numpyro.plate("x2", E):
            numpyro.sample("phi", dist.Delta(phi_latent))

but this doesn’t work because the phi_latent site doesn’t exist in the model.

I found this thread which seems to be asking the same thing; my best reading of the solution there is the sketch below, but I haven’t been able to make it work in my setup.
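A sketch of what I mean (assuming is_auxiliary is the right mechanism here; the flat sample gets reshaped to the (D, E) plate shape):

def guide_MVN_aux(D, E):
    n = D * E
    loc = numpyro.param("loc", jnp.zeros(n))
    # parameterise the covariance through its Cholesky factor, which is
    # easier to constrain than the full covariance matrix
    scale_tril = numpyro.param(
        "scale_tril",
        jnp.eye(n) * 0.1,
        constraint=dist.constraints.lower_cholesky,
    )
    phi_latent = numpyro.sample(
        "phi_latent",
        dist.TransformedDistribution(
            dist.MultivariateNormal(loc, scale_tril=scale_tril),
            dist.transforms.SigmoidTransform(),
        ),
        infer={"is_auxiliary": True},  # guide-only site, not matched against the model
    )
    with numpyro.plate("x1", D):
        with numpyro.plate("x2", E):
            # reshape the flat (D*E,) sample to the (D, E) plate shape
            numpyro.sample("phi", dist.Delta(phi_latent.reshape(D, E)))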

Would appreciate any help. Thanks!

Hi again.

What I’m actually trying to do is get a manual guide for a discrete mixture model to work, as suggested by @fehiepsi here. This is what I have. First up, the model:

def mix_weights(beta):
    """
    Function to do the stick breaking construction
    """
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    term1 = jnp.pad(beta, (0, 1), mode='constant', constant_values=1.)
    term2 = jnp.pad(beta1m_cumprod, (1, 0), mode='constant', constant_values=1.)
    return jnp.multiply(term1, term2)
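# quick sanity check of mix_weights: for any beta in (0, 1)^(K-1) the K
# returned weights sum to 1, e.g. mix_weights(jnp.array([0.5, 0.5]))
# gives [0.5, 0.25, 0.25]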


def model(
    K, 
    N, 
    D_discrete,
    X_mixture=None,
    alpha=None
):  
        
    # priors
    if alpha is None:
        alpha = numpyro.sample("alpha", dist.Uniform(0.3, 10.0))
    
    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample("v", dist.Beta(1, alpha))
        
    cluster_probabilities = numpyro.deterministic(
        "cluster_proba", 
        mix_weights(v), 
    )
        
    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent",
            dist.Normal(
                loc=jnp.zeros(D_discrete * K).reshape(K, D_discrete),
                scale=jnp.ones(D_discrete * K).reshape(K, D_discrete)
            ).to_event(1)
        )
    
    phi = numpyro.deterministic("phi", jax.nn.sigmoid(_phi_latent))
        
    # model sampling
    with numpyro.plate('data', N):
        
        # Assignment is which cluster each row of data belongs to.
        assignment = numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities),
            infer={'enumerate': 'parallel'}
        )

        obs = numpyro.sample(
            'obs', 
            dist.Bernoulli(phi[assignment, :]).to_event(1),
            obs=X_mixture,
        )

And the guide:

def guide(
    K, 
    N, 
    D_discrete,
    X_mixture=None,
    alpha=None,
):

    n_vars = 1 + (K - 1) + D_discrete * K  # alpha, v, and the phi latents

    _latent_loc = numpyro.param("_latent_loc", jnp.zeros(n_vars))
    # initial lower-triangular factor: 0.1 on the diagonal, 0.01 below it
    tmp = jnp.identity(n_vars) * 0.1 + jnp.tril(jnp.full((n_vars, n_vars), 0.01), k=-1)

    _latent_L = numpyro.param(
        "_latent_L",
        tmp,
        # lower_cholesky rather than corr_cholesky: scale_tril is a general
        # lower-triangular factor with a positive diagonal
        constraint=dist.constraints.lower_cholesky,
    )
    
    _latent_distribution = numpyro.sample(
        "_latent_distribution", 
        dist.MultivariateNormal(_latent_loc, scale_tril=_latent_L),
        infer={'is_auxiliary': True}
    )

    if alpha is None:
        # map the unconstrained latent into the Uniform(0.3, 10.0) support
        numpyro.sample(
            "alpha",
            dist.Delta(0.3 + 9.7 * jax.nn.sigmoid(_latent_distribution[0]))
        )
    
    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample(
            "v",
            dist.Delta(
                jax.nn.sigmoid(_latent_distribution[1:K])
            ),
        )

    cluster_probabilities = numpyro.deterministic(
        "cluster_proba",
        mix_weights(v)
    )
    
    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent", 
            dist.Delta(
                _latent_distribution[K:K + D_discrete * K].reshape(K, D_discrete)
            ).to_event(1)
        )
    phi = numpyro.deterministic(
        "phi",
        jax.nn.sigmoid(_phi_latent)
    )

    with numpyro.plate('data', N):
        assignment = numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities)
        )

The idea is to parameterise everything as a single multivariate normal, keeping track of all the off-diagonal elements of the covariance matrix between the model parameters for better estimation of the parameter uncertainties, and to transform each variable onto its appropriate range (e.g. the phis are probabilities, so they get hit with a sigmoid).
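Concretely, the flat latent vector is laid out like this:

# n_vars = 1 + (K - 1) + D_discrete * K
# _latent_distribution[0]                     -> alpha
# _latent_distribution[1:K]                   -> v (K - 1 entries, sigmoid -> (0, 1))
# _latent_distribution[K:K + D_discrete * K]  -> _phi_latent, reshaped to (K, D_discrete)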

I think I have set everything up correctly. It wasn’t obvious which constraint is right for a lower-Cholesky factor (I’ve gone with lower_cholesky rather than corr_cholesky), but that’s the only constraint needed, as the transformations take care of everything else.

Ideally, I would like this to work with TraceEnum_ELBO, but it does not. I get a key error:

KeyError: '_latent_distribution'

which I think is because that site is sampled in the guide but not in the model. The documentation seems to suggest that the is_auxiliary option should make this work, but TraceEnum_ELBO seems to ignore it - is this expected?

TraceGraph_ELBO does work with this setup, but I can’t get good parameter estimates with it, I think for the reasons alluded to by Martin in this post.

Any suggestions?

EDIT - I’ve put this code in a Google Colab, in case it’s useful.

if i understand it right you should be able to use block together with vanilla guides like AutoMultivariateNormal? e.g. on the pyro side => pyro/examples/capture_recapture/cjs.py at dev · pyro-ppl/pyro · GitHub

Thanks Martin. This works using TraceGraph_ELBO but not TraceEnum_ELBO - I get this error:

KeyError: 'v_plate'

For completeness, this is what I’m doing in numpyro:

rng = jax.random.PRNGKey(0)

guide = numpyro.infer.autoguide.AutoMultivariateNormal(
    numpyro.handlers.block(
        numpyro.handlers.seed(model, rng),
        hide_fn=lambda site: site["name"] == "assignment"
    )
)
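and then fitting with SVI along these lines (a sketch; X here stands in for my X_mixture data matrix):

optimizer = numpyro.optim.Adam(step_size=1e-3)
svi = numpyro.infer.SVI(model, guide, optimizer, numpyro.infer.TraceGraph_ELBO())
svi_result = svi.run(jax.random.PRNGKey(1), 5000, K, N, D_discrete, X_mixture=X)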