After trying a few other search terms, I found a solution in another thread and adapted its code below. Does this approach make sense?
import jax.numpy as jnp
import numpyro
from numpyro.infer import MCMC, NUTS
import numpy as np
import time
import torch
import numpyro
from numpyro.distributions import Weibull, HalfCauchy, Uniform
import jax.random
# Generate data: three groups of Weibull draws with different true parameters.
# Each group gets its OWN subkey. Reusing a single PRNGKey for every draw makes
# the groups' underlying uniform variates come from the same random stream, so
# the groups are not independent. `key` is the leftover, still-unused subkey,
# so later code (e.g. the sampler) can consume it without reusing data keys.
key, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(123), 4)
alpha_true = 2
beta_true = 3
N1 = 55
data = jax.random.weibull_min(key1, scale=alpha_true, concentration=beta_true, shape=(N1,))
alpha_true_2 = 6
beta_true_2 = 5
N2 = 100
data2 = jax.random.weibull_min(key2, scale=alpha_true_2, concentration=beta_true_2, shape=(N2,))
alpha_true_3 = 9
beta_true_3 = 12
N3 = 1
data3 = jax.random.weibull_min(key3, scale=alpha_true_3, concentration=beta_true_3, shape=(N3,))
data_list = [data, data2, data3]
sample_sizes = np.array([N1, N2, N3])
N_Max = int(sample_sizes.max())

# Pad every group to N_Max so the groups can be stacked into one rectangular
# array. The filler is never scored: the model masks rows i >= sample_sizes[g].
# It is positive (not 0) so the masked-out Weibull log-density stays finite;
# at x = 0 the log-density is +/-inf unless the shape parameter is exactly 1.
PAD_VALUE = 0.01
dl = [
    jnp.pad(item, (0, N_Max - int(n)), constant_values=PAD_VALUE)
    for n, item in zip(sample_sizes, data_list)
]
# Combine samples: shape (N_Max, n_groups), one column per group.
data_final = jnp.stack(dl, axis=1)
def model(sample_sizes, data=None):
    """Hierarchical Weibull model for several groups of unequal size.

    Two shared hyperparameters, "common_scale" and "common_shape", get
    Uniform(0.5, 5) priors. Each group then draws its own Weibull "scale" and
    "shape" from a HalfCauchy whose scale is the corresponding hyperparameter.
    The Uniform bounds therefore constrain the HalfCauchy scales, not the group
    parameters themselves. Observations "y" are scored per group, with padded
    rows masked out.

    Args:
        sample_sizes: 1-D integer array-like giving the number of real
            observations in each group; its length is the number of groups.
        data: array of shape (max(sample_sizes), n_groups), one column per
            group. Rows at or beyond a group's size are filler: they are
            masked out, but should be positive so their log-density stays
            finite. If None, "y" is sampled instead of conditioned on.
    """
    sample_sizes = np.asarray(sample_sizes)
    n_groups = sample_sizes.shape[0]
    uniform_data = (0.5, 5)
    common_prior_scale = numpyro.sample("common_scale", Uniform(*uniform_data))
    common_prior_shape = numpyro.sample("common_shape", Uniform(*uniform_data))
    # Outer plate -> rightmost batch dim (-1): one entry per group.
    with numpyro.plate("n_groups", n_groups):
        scale = numpyro.sample("scale", HalfCauchy(common_prior_scale))
        shape = numpyro.sample("shape", HalfCauchy(common_prior_shape))
        max_size = int(sample_sizes.max())
        # Row index as a column vector (max_size, 1): comparing it against
        # sample_sizes (n_groups,) broadcasts to a (max_size, n_groups) mask,
        # the same layout as `data`. Plain jnp; no torch round-trip needed.
        row_idx = jnp.arange(max_size)[:, None]
        # Inner plate -> batch dim -2: one entry per row/observation slot.
        with numpyro.plate("data", max_size):
            # Padded rows (row_idx >= group size) contribute nothing to the log-density.
            with numpyro.handlers.mask(mask=row_idx < sample_sizes):
                numpyro.sample("y", Weibull(scale, shape), obs=data)
# NUTS sampler settings. These are passed to MCMC below, so this line is the
# single source of truth for the run length.
num_warmup, num_samples = 1000, 3000
# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
# Wall-clock time of the whole run() call, including any JIT compilation
# that happens inside it on the first call.
start = time.perf_counter()
mcmc.run(key, sample_sizes=sample_sizes, data=data_final)
end = time.perf_counter()
print(end - start)
mcmc.print_summary()