Implementing Gibbs on a subset of a plated site

Hello everyone

I was wondering whether there is an interface that makes it easy to apply a Gibbs update to only a subset of a sample site.

In the toy model below, betas has shape (x.shape[0],). Is there an easy way to apply a Gibbs update to, say, only the first element of betas?

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

x = jnp.arange(10.0)
y = jnp.arange(10.0)

def model(x, y=None):
    with numpyro.plate("data", x.shape[0]):
        betas = numpyro.sample("betas", dist.Normal())
        numpyro.sample("obs", dist.Normal(betas * x), obs=y)

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # Goal: set betas[0] ~ Normal(100, 0.1) and leave the other elements alone
    return

mcmc_object = MCMC(
    NUTS(model), num_samples=1000, num_warmup=1000, num_chains=2, progress_bar=True
)
mcmc_object.run(random.PRNGKey(0), x=x, y=y)

In the absence of an explicit interface, one workaround is to add a separate scalar site and splice it into betas:

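def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # Update only the scalar site "z"
    return {"z": dist.Normal().sample(rng_key)}

def model(x, y=None):
    ...  # plate and betas as in the model above
    z = numpyro.sample("z", dist.Normal())
    # Overwrite the first coefficient with the Gibbs-updated z
    new_betas = numpyro.deterministic("new_betas", betas.at[0].set(z))
    ...  # rest of the model, using new_betas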

And in this case, is there a more efficient way to construct the desired new_betas?

Thank you

Not sure what you're asking exactly, but gibbs_fn can perform any valid Gibbs update, including updating a single element of gibbs_sites['betas'].

Thanks for replying @martinjankowiak
Specifically, how would I update only the first element of gibbs_sites['betas']? The interface seems to accept only a dictionary that replaces every element of a sample site.

You still return the full dictionary containing every Gibbs site, but you needn't modify every element of every tensor:

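def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    old_betas = gibbs_sites['betas']
    # Replace only element 0; betas[1:] are returned unchanged
    new_betas = old_betas.at[0].set(some_custom_function(hmc_sites, rng_key, ...))
    return {"betas": new_betas}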

Solved, thank you.