Using gibbs_fn with additional parameters

Hi all, I’m trying to implement a custom Gibbs function to update the discrete variables within a custom distribution class. I want to update a value within gibbs_fn (which is defined outside the custom class) so that it can then be used in the log_prob method of the custom distribution. Is this possible? Thanks for your help!

Can you please describe what you want in more detail? I’m afraid I don’t follow.

Sorry about that; hopefully this clears up what I’m trying to achieve. There is a lot going on in the code outside of where this issue is, so this is a simplified pseudocode version of what I’m trying to do:

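import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, HMCGibbs


class CustomDist(dist.Distribution):
    def __init__(self, model):
        # keep a reference to the BayesianModel instance so log_prob can read model.val
        self.model = model
        super().__init__(batch_shape=(), event_shape=())

    def log_prob(self, data):
        # want to use the updated bayes_model.val here (i.e. self.model.val)
        ...


class BayesianModel:
    def __init__(self, val):
        # deterministic switch, not a sample site
        self.val = val

    def model(self, data):
        # sample continuous and discrete variables (including the Gibbs site 'i')
        ...
        return numpyro.sample('p', CustomDist(self), obs=data)

    def gibbs_fn(self, rng_key, gibbs_sites, hmc_sites):
        # want to update self.val here, then return new values for the Gibbs sites
        ...


def main(data):
    bayes_model = BayesianModel(val=1)
    hmc_kernel = NUTS(bayes_model.model)
    kernel = HMCGibbs(hmc_kernel, gibbs_fn=bayes_model.gibbs_fn, gibbs_sites=['i'])
    # num_warmup / num_samples are required by MCMC; these values are illustrative
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)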

Basically, I want to update a value in gibbs_fn that is then used in the log_prob of CustomDist. The main issue is that gibbs_fn needs attributes from bayes_model, so it has to be a method of that class, but it also needs to update a value that is not used for inference, only in the calculation of log_prob. Let me know if that doesn’t make sense.

I’m not sure I entirely follow, but I imagine there should be some way to get what you want.

Where is your sample statement for val? What prior is val governed by?

gibbs_fn is invoked here. As long as you use the provided rng and do a legitimate Gibbs update, just about anything else should be allowed.

There is no sample statement for val because it is not a random variable; it is deterministic. Its purpose is to switch on parts of the model that are evaluated in log_prob.

Is it possible to update a variable that isn’t a sample site within gibbs_fn? I looked through the source code you linked, but I couldn’t see a simple way to also include extra variables that are not sampled.

Given the current interface, your best bet might be to make val a random variable with a trivial prior: e.g. a uniform prior, or use mask(False) to short-circuit the log prob of some other distribution. Then it will be available to gibbs_fn through the existing interface.

http://num.pyro.ai/en/stable/distributions.html?highlight=mask#numpyro.distributions.distribution.Distribution.mask

Great, thanks for your help! I will try that out, and hopefully it will solve the problem.