Hello everyone
I was wondering if there was some interface for the user to easily apply a Gibbs marginal to a subset of a sample site.
In the toy model below, we have betas of shape x.shape[0]. Is there an easy interface to apply a Gibbs marginal for, say, the first element of betas?
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

x = jnp.arange(10)
y = jnp.arange(10)

def model(x, y=None):
    # One beta per data point
    with numpyro.plate("", x.shape[0]):
        betas = numpyro.sample("betas", dist.Normal())
        numpyro.sample("obs", dist.Normal(betas * x), obs=y)

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    return  # Set betas[0] to ~ Normal(100, 0.1)

mcmc_object = MCMC(
    NUTS(model), num_samples=1000, num_warmup=1000, num_chains=2, progress_bar=True
)
mcmc_object.run(random.PRNGKey(0), x=x, y=y)
In the absence of an explicit interface, we can use:
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    return {"z": dist.Normal().sample(rng_key)}

def model(x, y=None):
    ...
    z = numpyro.sample("z", dist.Normal())
    new_betas = numpyro.deterministic("new_betas", betas.at[0].set(z))
    ...
And in this case, is there a more efficient way to construct the desired new_betas?
Thank you