Enhance NUTS sampling performance - model fitting is slow

Hello,
I have a hierarchical model of an experiment measure, where the experiments belong to different categories. Each category has a different number of experiments, so the input 2D array is ragged and is modelled as follows: the first dimension is the categories and the second dimension is the experiments. I use padding and masking, as suggested here (Hierarchical Bayesian Model with data of varying length - #2 by fritzo), to work with the ragged structure of the input.
I am fitting the model using NUTS, but it is very slow. For an input with dimensions (27, 3500) and using 29 CPUs, it takes 19 hours, which is very slow compared to beanmachine, which takes 1.5 hours.
I tried using HMC, but it gets stuck on a single chain, and the other chains stay stuck at 0 progress.

Here is my code:

import numpyro
import numpyro.distributions as dist

def numpyro_model(observations, mask_array, sigma_i):
    category_count, exp_count = observations.shape
    sigma_mu = numpyro.sample("sigma_mu", dist.HalfCauchy(5))
    with numpyro.plate("category", category_count, dim=-2):
        theta = numpyro.sample("theta", dist.Beta(2, 5))
        mu_zero = numpyro.sample("mu_0", dist.Normal(0, 1))
        # for each experiment in the category
        with numpyro.plate("experiment", exp_count):
            # mask out the padded entries of the ragged array
            with numpyro.handlers.mask(mask=mask_array):
                z = numpyro.sample("z", dist.Bernoulli(theta), infer={"enumerate": "parallel"})
                sigma_i = numpyro.deterministic("sigma_i", sigma_i)
                mu = numpyro.sample("mu", dist.Normal(mu_zero, sigma_mu))
                # experiment metric:
                x = numpyro.sample("x", dist.Normal(mu * z, sigma_i), obs=observations)

kernel = numpyro.infer.NUTS(numpyro_model)
numpyro_mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=2500,
    num_samples=7500,
    num_chains=4,
)
numpyro_mcmc.run(rng_key, observations=x_obs_array, mask_array=mask_array, sigma_i=sigma_i_array)

Any suggestions on how to speed up the fitting? Also, how should I set the NUTS or HMC parameters appropriately?
Thanks!

One thing you can try is reparametrizing; see example 1 here.

E.g., something like:

 mu = mu_zero + sigma_mu * numpyro.sample("mu", dist.Normal(0, 1))