Using plate and vectorization for differently sized sequences

After trying a few other search terms, I found a solution in another thread and adapted its code below. Does this approach make sense?

import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS
import numpy as np
import time
import torch

import numpyro
from numpyro.distributions import Weibull, HalfCauchy, Uniform
import jax.random

# Generate Data: three groups drawn from Weibull distributions with different
# scale (alpha) / concentration (beta) parameters and different sample sizes.
key = jax.random.PRNGKey(123)
# JAX PRNG keys must not be reused: drawing every group from the same key
# makes the groups statistically dependent. Give each group its own subkey
# and keep the leftover `key` fresh for later use (e.g. running MCMC).
key, key1, key2, key3 = jax.random.split(key, 4)

alpha_true = 2
beta_true = 3
N1 = 55
data = jax.random.weibull_min(key1, scale=alpha_true, concentration=beta_true, shape=(N1,))

alpha_true_2 = 6
beta_true_2 = 5
N2 = 100
data2 = jax.random.weibull_min(key2, scale=alpha_true_2, concentration=beta_true_2, shape=(N2,))

alpha_true_3 = 9
beta_true_3 = 12
N3 = 1
data3 = jax.random.weibull_min(key3, scale=alpha_true_3, concentration=beta_true_3, shape=(N3,))

data_list = [data, data2, data3]
sample_sizes = np.array([N1, N2, N3])
N_Max = int(sample_sizes.max())

# Pad every group to N_Max. The padded entries are masked out of the
# likelihood in `model`; the small positive value only keeps the Weibull
# log-density finite on those slots.
dl = [
    jnp.concatenate([item, jnp.full((int(N_Max - n),), 0.01)])
    for n, item in zip(sample_sizes, data_list)
]

# Combine samples: shape (N_Max, n_groups), one column per group.
data_final = jnp.stack(dl, axis=1)

def model(sample_sizes, data=None):
    """Hierarchical Weibull model for several groups of unequal size.

    Each group gets its own Weibull ``scale`` and ``shape``, drawn from
    HalfCauchy priors whose scale parameters are shared hyperparameters with
    Uniform(0.5, 5) priors. Groups are vectorized with a ``plate``;
    observations are padded to the largest group size and the padded slots are
    masked out of the likelihood.

    Args:
        sample_sizes: 1-D array-like with the number of real observations in
            each group; its length is the number of groups. Must be concrete
            (not traced), because its maximum sizes the observation plate.
        data: Optional array of shape ``(max(sample_sizes), n_groups)`` whose
            column ``g`` holds group ``g``'s observations followed by padding
            (the layout produced by transposing the stacked, padded groups).
            When ``None``, ``y`` is sampled instead of observed.
    """
    sample_sizes = np.asarray(sample_sizes)

    # Number of groups from which we've obtained samples.
    n_groups = sample_sizes.shape[0]
    # Largest group size; every group is padded to this length.
    max_size = int(sample_sizes.max())

    uniform_data = (0.5, 5)
    common_prior_scale = numpyro.sample("common_scale", Uniform(*uniform_data))
    common_prior_shape = numpyro.sample("common_shape", Uniform(*uniform_data))

    # Row index as a (max_size, 1) column; comparing it with sample_sizes
    # broadcasts to a (max_size, n_groups) mask that is True only for real
    # (non-padded) observations.
    obs_index = jnp.arange(max_size)[:, None]
    valid = obs_index < sample_sizes

    # One scale/shape pair per group.
    with numpyro.plate("n_groups", n_groups):
        scale = numpyro.sample("scale", HalfCauchy(common_prior_scale))
        shape = numpyro.sample("shape", HalfCauchy(common_prior_shape))

        # One likelihood site per (padded) observation slot.
        with numpyro.plate("data", max_size):
            # Padded slots contribute nothing to the log-likelihood.
            with numpyro.handlers.mask(mask=valid):
                numpyro.sample("y", Weibull(scale, shape), obs=data)


# Warm-up and sampling lengths; these are the values passed to the sampler.
num_warmup, num_samples = 1000, 3000

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

# perf_counter is a monotonic, high-resolution clock meant for elapsed time.
# NOTE(review): `key` should be a PRNG key not already consumed for data
# generation (split it first if it was).
start = time.perf_counter()
mcmc.run(key, sample_sizes=sample_sizes, data=data_final)
end = time.perf_counter()
print(end - start)
mcmc.print_summary()