Hey there!
I have quite a simple hierarchical (num)pyro model: I am looking at n machine parts, for which I want to estimate the mean time between failures. For that, I am using a Weibull distribution. Since the machine parts are independent of each other, each one gets its own likelihood, but I want them to share a common prior.
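Written out, the hierarchy implemented by the model code below is the following. The symbols are my own shorthand: $\lambda_0, k_0$ stand for `common_scale` and `common_shape`, $\lambda_j, k_j$ for `scale_j` and `shape_j` (Weibull scale and shape/concentration), and $t_{ij}$ is the $i$-th observed lifetime of part $j$, which has $m_j$ observations:

$$
\begin{aligned}
\lambda_0 &\sim \mathrm{Uniform}(0.5,\, 5), & k_0 &\sim \mathrm{Uniform}(0.5,\, 5),\\
\lambda_j &\sim \mathrm{HalfCauchy}(\lambda_0), & k_j &\sim \mathrm{HalfCauchy}(k_0), \qquad j = 0, \dots, n-1,\\
t_{ij} &\sim \mathrm{Weibull}(\lambda_j,\, k_j), & & \qquad i = 0, \dots, m_j - 1.
\end{aligned}
$$

The per-part index $j$ and the per-observation index $i$ are the two levels of conditional independence that a vectorized version would express with plates.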
I have built the model (with some sample data):
```python
import numpyro
import numpyro.infer as infer
from numpyro.distributions import Weibull, HalfCauchy, Uniform
import jax.random

# Split the seed so each dataset and the MCMC run get their own PRNG key
key = jax.random.PRNGKey(123)
key, key1, key2, key3 = jax.random.split(key, 4)

alpha_true = 2
beta_true = 3
N = 55
data = jax.random.weibull_min(key1, scale=alpha_true, concentration=beta_true, shape=(N,))

alpha_true_2 = 6
beta_true_2 = 5
N = 100
data2 = jax.random.weibull_min(key2, scale=alpha_true_2, concentration=beta_true_2, shape=(N,))

alpha_true_3 = 9
beta_true_3 = 12
N = 1
data3 = jax.random.weibull_min(key3, scale=alpha_true_3, concentration=beta_true_3, shape=(N,))

data_final = [data, data2, data3]


def model(data):
    uniform_data = (0.5, 5)
    # Hyperpriors shared by all machine parts
    common_prior_scale = numpyro.sample("common_scale", Uniform(*uniform_data))
    common_prior_shape = numpyro.sample("common_shape", Uniform(*uniform_data))
    # One scale/shape pair and one likelihood per machine part (a Python loop, not a plate)
    for i, group in enumerate(data):
        scale = numpyro.sample(f"scale_{i}", HalfCauchy(common_prior_scale))
        shape = numpyro.sample(f"shape_{i}", HalfCauchy(common_prior_shape))
        numpyro.sample(f"y_{i}", Weibull(scale, shape).expand([len(group)]), obs=group)


nuts_kernel = infer.NUTS(model)
mcmc = infer.MCMC(nuts_kernel, num_samples=3000, num_warmup=1000)
mcmc.run(key, data_final)

# Summarize the results
mcmc.print_summary()
```
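Since the quantity of interest is the mean time between failures rather than the Weibull parameters themselves, note that a Weibull lifetime with scale $\lambda$ and concentration $k$ has mean $\lambda\,\Gamma(1 + 1/k)$. A minimal sketch, assuming JAX, that turns posterior draws of scale and shape into draws of the MTBF (the helper name `weibull_mtbf` is hypothetical):

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def weibull_mtbf(scale, concentration):
    """Mean of a Weibull(scale, concentration) lifetime: scale * Gamma(1 + 1/concentration).

    Works elementwise, so it can be applied directly to arrays of posterior samples.
    """
    scale = jnp.asarray(scale)
    concentration = jnp.asarray(concentration)
    # exp(gammaln(x)) == Gamma(x) for x > 0, computed in log space for stability
    return scale * jnp.exp(gammaln(1.0 + 1.0 / concentration))
```

For example, with `samples = mcmc.get_samples()`, `weibull_mtbf(samples["scale_0"], samples["shape_0"])` would give posterior MTBF draws for the first machine part, from which a posterior mean or credible interval can be read off.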
With numpyro, this model already runs quite fast. However, these are just dummy data, and I am expecting more data (and more machine parts) to come. I could of course switch to SVI, but that will come later. I figured it must be possible to vectorize the whole thing and use a plate to model the conditional independence between the groups, but I am completely lost on how to implement this. Any help would be appreciated.
Thanks and best regards,
m