Using TransformReparam with Gamma distribution?

Hi, I tried the BTYD ("Buy 'Til You Die") model below and got 879 divergences!

So I tried to use a reparam config with `LocScaleReparam(0)`, but that does not work with the Gamma and InverseGamma distributions:

from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam
# NOTE: LocScaleReparam decenters a site through its ``loc``/``scale``
# parameters, so it only fits location-scale distributions.  The "etau"
# (InverseGamma) and "Lambda" (Gamma) sites of the model have no ``loc``,
# which is why running reparam_model fails with
# AttributeError: 'InverseGamma' object has no attribute 'loc'.
reparam_config = {
    site_name: LocScaleReparam(centered=0.0)
    for site_name in ("etau", "Lambda")
}
reparam_model = reparam(model, config=reparam_config)

AttributeError: 'InverseGamma' object has no attribute 'loc'

I guess I need to use TransformReparam, but I am not sure how to formulate it for distributions parameterized by concentration and rate, like Gamma and InverseGamma…
Can anyone help? Thanks!

def pareto_nbd_loglik(Lambda, mu, t, T, k):
    """Return the total log-likelihood summed over customers.

    For each customer ``i`` the term is::

        k_i * log(Lambda_i) - log(Lambda_i + mu_i)
        + logaddexp(log(Lambda_i) - (Lambda_i + mu_i) * T_i,
                    log(mu_i)     - (Lambda_i + mu_i) * t_i)

    (the individual-level Pareto/NBD likelihood).  Arguments are combined
    elementwise with broadcasting, and the per-customer terms are summed once,
    so the result is a scalar.
    """
    rate_sum = Lambda + mu
    # Everything is elementwise: previously each scalar logaddexp term was
    # added to *every* entry of a length-n vector inside a Python loop, so
    # summing that vector counted each customer's term n times.
    per_customer = (
        k * jnp.log(Lambda)
        - jnp.log(rate_sum)
        + jnp.logaddexp(
            jnp.log(Lambda) - rate_sum * T,
            jnp.log(mu) - rate_sum * t,
        )
    )
    return jnp.sum(per_customer)


def model_mixed_priors(t, T, k, prior_only=True):
    """Hierarchical BTYD model with InverseGamma / Gamma customer-level priors.

    Args:
        t: per-customer array; its ``size`` sets the length of the "data"
            plate (presumably time of last purchase -- confirm with the data).
        T: per-customer array used in the likelihood (presumably the length
            of each customer's observation window -- confirm).
        k: per-customer array used in the likelihood (presumably the number
            of repeat purchases -- confirm).
        prior_only: when True (the default) only the four hyperparameters
            are sampled; the customer-level sites and the likelihood factor
            are skipped.

    Note: ``LocScaleReparam`` cannot be applied to the ``etau`` / ``Lambda``
    sites because InverseGamma and Gamma have no ``loc`` / ``scale``.
    """
    etau_alpha = numpyro.sample('etau_alpha', dist.Uniform(0.01, 2))
    etau_beta = numpyro.sample('etau_beta', dist.HalfNormal(scale=200))
    Lambda_alpha = numpyro.sample('Lambda_alpha', dist.Uniform(0.01, 2))
    Lambda_beta = numpyro.sample('Lambda_beta', dist.HalfNormal(scale=200))

    if not prior_only:
        with numpyro.plate("data", t.size):
            etau = numpyro.sample('etau', dist.InverseGamma(etau_alpha, etau_beta))  # mean lifetime
            mu = numpyro.deterministic('mu', 1. / etau)
            Lambda = numpyro.sample('Lambda', dist.Gamma(Lambda_alpha, Lambda_beta))
            one_over_Lambda = numpyro.deterministic('one_over_Lambda', 1. / Lambda)
        numpyro.factor('loglik', pareto_nbd_loglik(Lambda, mu, t, T, k))


# Fit the full model (prior_only=False) with NUTS: 250 warmup iterations
# followed by 1000 samples, seeded with PRNGKey(0).
model = model_mixed_priors
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(model)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=250,
    num_samples=1000,
)
mcmc.run(rng_key, t, T, k, prior_only=False)
# In the original post this summary reported 879 divergences.
mcmc.print_summary()

Ah, never mind — I just switched to LogNormal priors and will apply the reparam to that distribution instead.

1 Like