Using plate and vectorization for differently sized sequences

Hey there!

I have a fairly simple hierarchical (num)pyro model: I am looking at n machine parts, and for each of them I want to estimate the mean time between failures. For that, I am using a Weibull distribution. Since the machine parts are independent of each other, each gets its own likelihood; however, I want them to share a common prior.

Here is the model I have built so far (including some sample data):

import numpyro
import numpyro.infer as infer
from numpyro.distributions import Weibull, HalfCauchy, Uniform
import jax.random

# Generate three synthetic groups of Weibull-distributed lifetimes.
# JAX PRNG keys must never be reused: drawing every group with the same key
# means the groups are not independent random draws.  Split the root key once
# and give each draw its own subkey; ``key`` itself is rebound to a fresh,
# still-unused subkey so it can seed the sampler afterwards.
key = jax.random.PRNGKey(123)
key, key1, key2, key3 = jax.random.split(key, 4)

alpha_true = 2
beta_true = 3
N = 55
data = jax.random.weibull_min(key1, scale=alpha_true, concentration=beta_true, shape=(N,))

alpha_true_2 = 6
beta_true_2 = 5
N = 100
data2 = jax.random.weibull_min(key2, scale=alpha_true_2, concentration=beta_true_2, shape=(N,))

alpha_true_3 = 9
beta_true_3 = 12
N = 1
data3 = jax.random.weibull_min(key3, scale=alpha_true_3, concentration=beta_true_3, shape=(N,))

# The groups have different lengths, so they are kept in a plain list.
data_final = [data, data2, data3]

def model(data):
    """Hierarchical Weibull model with a separate likelihood per group.

    Two hyperparameters, ``common_scale`` and ``common_shape``, are drawn from
    Uniform(0.5, 5).  Each group ``i`` in ``data`` then receives its own
    ``scale_i`` and ``shape_i`` from HalfCauchy priors scaled by those
    hyperparameters, and its observations are scored under
    ``Weibull(scale_i, shape_i)`` via the observed site ``y_i``.

    Args:
        data: iterable of 1-D observation arrays, one per group; the groups
            may have different lengths.
    """
    low, high = 0.5, 5
    common_prior_scale = numpyro.sample("common_scale", Uniform(low, high))
    common_prior_shape = numpyro.sample("common_shape", Uniform(low, high))

    # ``idx`` (rather than ``key``) avoids shadowing the module-level PRNG key.
    for idx, group in enumerate(data):
        group_scale = numpyro.sample(f"scale_{idx}", HalfCauchy(common_prior_scale))
        group_shape = numpyro.sample(f"shape_{idx}", HalfCauchy(common_prior_shape))
        # Broadcast the group's Weibull to one batch element per observation.
        likelihood = Weibull(group_scale, group_shape).expand([len(group)])
        numpyro.sample(f"y_{idx}", likelihood, obs=group)

# Posterior sampling with NUTS: 1000 warm-up iterations followed by 3000 draws.
nuts_kernel = infer.NUTS(model)
mcmc = infer.MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=3000,
)
mcmc.run(key, data=data_final)

# Print a summary of the posterior draws for every latent site.
mcmc.print_summary()

With numpyro this is already quite fast; however, these are just dummy data, and I expect more data (and more machine parts) to come. Of course, I might switch to SVI, but that will come later. I thought it must be possible to vectorize the whole thing and use a plate to model the conditional independence between the groups, but I am completely lost on how to implement this. Any help would be appreciated.

Thanks and best regards,
m

After trying a few other search terms, I found a solution in another thread and adapted its code. Does this make sense?

import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS
import numpy as np
import time
import torch

import numpyro
from numpyro.distributions import Weibull, HalfCauchy, Uniform
import jax.random

# Generate Data
# JAX PRNG keys must never be reused: drawing every group with the same key
# means the groups are not independent random draws.  Split the root key once
# and give each draw its own subkey; ``key`` itself is rebound to a fresh,
# still-unused subkey so it can seed the sampler afterwards.
key = jax.random.PRNGKey(123)
key, key1, key2, key3 = jax.random.split(key, 4)

alpha_true = 2
beta_true = 3
N1 = 55
data = jax.random.weibull_min(key1, scale=alpha_true, concentration=beta_true, shape=(N1,))

alpha_true_2 = 6
beta_true_2 = 5
N2 = 100
data2 = jax.random.weibull_min(key2, scale=alpha_true_2, concentration=beta_true_2, shape=(N2,))

alpha_true_3 = 9
beta_true_3 = 12
N3 = 1
data3 = jax.random.weibull_min(key3, scale=alpha_true_3, concentration=beta_true_3, shape=(N3,))

data_list = [data, data2, data3]
sample_sizes = np.array([N1, N2, N3])
N_Max = int(sample_sizes.max())

# Right-pad every group to N_Max so all groups fit in one rectangular array.
# The model masks padded rows out of the likelihood, but the filler still lies
# inside the Weibull support (> 0) so its (masked) log-density stays finite.
PAD_VALUE = 0.01
dl = [
    jnp.concatenate([item, jnp.full(N_Max - n, PAD_VALUE)])
    for n, item in zip(sample_sizes, data_list)
]

# Combine samples as columns: data_final[i, g] is the i-th observation of
# group g, giving shape (N_Max, number of groups).
data_final = jnp.stack(dl, axis=1)

def model(sample_sizes, data=None):
    """Vectorised hierarchical Weibull model over padded, differently sized groups.

    Two hyperparameters, ``common_scale`` and ``common_shape``, are drawn from
    Uniform(0.5, 5).  Each group then gets its own ``scale``/``shape`` from
    HalfCauchy priors scaled by those hyperparameters, and its observations
    are scored under ``Weibull(scale, shape)``.  Padded rows are masked out of
    the likelihood.

    Args:
        sample_sizes: 1-D array-like with the number of real observations in
            each group; its length is the number of groups.
        data: observations of shape ``(max(sample_sizes), n_groups)``, with
            each column right-padded beyond its group's sample size.  With
            ``data=None`` the ``y`` site is not observed.
    """
    sample_sizes = np.asarray(sample_sizes)
    n_groups = sample_sizes.shape[0]

    # Size of the observation plate: the longest group.  Computed once with
    # NumPy, outside the plates, so it is a static Python int.
    max_len = int(sample_sizes.max())
    # obs_mask[i, g] is True iff row i is a real observation of group g.  The
    # (max_len, n_groups) shape matches the layout of ``data``.
    obs_mask = np.arange(max_len)[:, None] < sample_sizes

    uniform_data = (0.5, 5)
    common_prior_scale = numpyro.sample("common_scale", Uniform(*uniform_data))
    common_prior_shape = numpyro.sample("common_shape", Uniform(*uniform_data))

    # One scale/shape pair per group.
    with numpyro.plate("n_groups", n_groups):
        scale = numpyro.sample("scale", HalfCauchy(common_prior_scale))
        shape = numpyro.sample("shape", HalfCauchy(common_prior_shape))

        # Observations are conditionally independent given their group's
        # parameters; the nested plate adds the row dimension so the
        # per-group parameters broadcast against (max_len, n_groups) data.
        with numpyro.plate("data", max_len):
            # Padded rows contribute nothing to the log-likelihood.
            with numpyro.handlers.mask(mask=obs_mask):
                numpyro.sample("y", Weibull(scale, shape), obs=data)


# Sampler configuration, passed straight to MCMC below so the values printed
# in the code are the ones actually used.
num_warmup, num_samples = 1000, 3000

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=num_samples, num_warmup=num_warmup)

# perf_counter is the monotonic, high-resolution clock meant for measuring
# elapsed time.  NOTE(review): the measured span presumably also includes the
# one-off JAX compilation of the model.
start = time.perf_counter()
mcmc.run(key, sample_sizes=sample_sizes, data=data_final)
end = time.perf_counter()
print(end - start)
mcmc.print_summary()