Speed up NumPyro Model

I’m seeking advice on improving the runtime performance of the NumPyro model below.

Model description: I have a dataset of L objects. For each object, I sample a discrete variable c and eight continuous variables: s, h, and six parameters theta_i that determine an analytical function (defined by the function dst). This function is fit to the observed data points, one fit per object. Ideally the fitting code would sit in a nested plate, but since the number of observed data points may differ between objects, I have instead flattened all the data points into a single vector V and do the fitting in a separate plate. The total number of data points is SL. The code and graphical model are at the end of this post.
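
A minimal sketch of the flattening described above, assuming the raw data arrives as one array of times and one array of values per object (`per_object_t` and `per_object_v` are hypothetical names, and the numbers are made up). `index_mapping[j]` records which object data point `j` belongs to, which is what lets the model gather one row of per-object parameters per data point:

```python
import numpy as np

# Hypothetical ragged input: each object has a different number of observations.
per_object_t = [np.array([0.0, 1.0, 2.0]), np.array([0.5]), np.array([1.0, 3.0])]
per_object_v = [np.array([1.1, 2.9, 3.0]), np.array([0.7]), np.array([2.2, 1.0])]
L = len(per_object_t)

# Flatten all data points into single vectors of length SL.
t = np.concatenate(per_object_t)
V_obs = np.concatenate(per_object_v)
SL = t.shape[0]

# index_mapping[j] is the object index of data point j.
counts = np.array([len(x) for x in per_object_t])
index_mapping = np.repeat(np.arange(L), counts)

# Gathering per-object parameters of shape (L, 6) yields one row per data point.
theta = np.arange(L * 6, dtype=float).reshape(L, 6)
theta_per_point = theta[..., index_mapping, :]
print(index_mapping)           # [0 0 0 1 2 2]
print(theta_per_point.shape)   # (6, 6)
```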

When testing on a dataset with L ~ 3000 and SL ~ 21000, it takes ~5 minutes to run 4 chains of 1000 iterations each (500 warmup, 500 samples) on a machine with an 8-core CPU and 32 GB of RAM. However, on a larger dataset with L ~ 3e5 and SL ~ 4e6, it takes ~1.5 days.
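
A back-of-the-envelope comparison of the two runs, using only the sizes and timings quoted above. It shows that runtime grew faster than the number of data points. One plausible reading, stated here as an assumption rather than a measurement, is that each gradient evaluation costs more (linear in SL) *and* NUTS needs more leapfrog steps per iteration as the latent dimension grows:

```python
# Sizes and wall-clock times quoted above (4 chains x 1000 iterations).
small = {"L": 3_000, "SL": 21_000, "minutes": 5}
large = {"L": 300_000, "SL": 4_000_000, "minutes": 1.5 * 24 * 60}

# 8 continuous latent sites per object (s, h, theta_1, theta_2, theta_3,
# theta_5, theta_6, length), so the HMC state has dimension 8 * L.
dim_small = 8 * small["L"]
dim_large = 8 * large["L"]

data_ratio = large["SL"] / small["SL"]             # growth in likelihood terms
time_ratio = large["minutes"] / small["minutes"]   # observed growth in runtime

print(f"latent dims: {dim_small} -> {dim_large}")
print(f"data points grew {data_ratio:.0f}x, runtime grew {time_ratio:.0f}x")
```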

I am seeking advice on how to speed up the sampling. I suspect that a vector of millions of data points is causing performance problems for NumPyro. Is there any way to tell the sampler to process the data in smaller batches?

The code is below. The distribution ImproperTruncatedNormal is just a Normal distribution with positive support, as suggested in this thread.

```python
import funsor
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Enumeration in NumPyro goes through funsor with the JAX backend.
funsor.set_backend("jax")


class ImproperTruncatedNormal(dist.Normal):
    # Normal density with support declared as (0, inf); log_prob is not renormalized.
    support = dist.constraints.positive


# Register the custom distribution with funsor so it can be used in an enumerated model.
funsor.distribution.make_dist(ImproperTruncatedNormal, param_names=("loc", "scale"))


def dst(theta, time):
    # theta[..., 0:6] holds theta_1..theta_6: baseline, height, start time,
    # end time, rise rate, fall rate.
    return theta[..., 0] + 0.5 * theta[..., 1] * (
        jnp.tanh(theta[..., 4] * (time - theta[..., 2]))
        - jnp.tanh(theta[..., 5] * (time - theta[..., 3]))
    )


def my_model(V_obs, t, index_mapping, L, pi, theta_mean, theta_std, s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL):
    # One set of latent variables per object.
    with numpyro.plate("L", L):
        # Discrete class of each object, marginalized out by parallel enumeration.
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})

        s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
        h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))

        theta_1 = numpyro.sample("theta_1", dist.Normal(loc=jnp.array(theta_mean[c, 0]), scale=jnp.array(theta_std[c, 0])))
        theta_2 = numpyro.sample("theta_2", ImproperTruncatedNormal(loc=jnp.array(theta_mean[c, 1]), scale=jnp.array(theta_std[c, 1])))
        theta_5 = numpyro.sample("theta_5", ImproperTruncatedNormal(loc=jnp.array(theta_mean[c, 2]), scale=jnp.array(theta_std[c, 4])))
        theta_6 = numpyro.sample("theta_6", ImproperTruncatedNormal(loc=jnp.array(theta_mean[c, 3]), scale=jnp.array(theta_std[c, 5])))

        # Class-dependent linear fits in s and h give the prior means of theta_3 and of the length.
        gamma_3 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
        gamma_4 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])
        gamma_length = gamma_4 - gamma_3
        sigma_gamma_length = jnp.sqrt(theta_std[c, 3]**2 - theta_std[c, 2]**2)

        theta_3 = numpyro.sample("theta_3", ImproperTruncatedNormal(loc=gamma_3, scale=theta_std[c, 2]))
        length = numpyro.sample("length", ImproperTruncatedNormal(loc=gamma_length, scale=sigma_gamma_length))
        theta_4 = numpyro.deterministic("theta_4", theta_3 + length)

        theta = numpyro.deterministic("theta", jnp.stack([theta_1, theta_2, theta_3, theta_4, theta_5, theta_6], axis=-1))

    # Flattened data points; index_mapping[j] is the object that point j belongs to.
    with numpyro.plate("SL", SL):
        v_t = dst(theta[..., index_mapping, :], t)
        V = numpyro.sample("V", dist.Normal(v_t, sigma), obs=V_obs)
```
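
To see what `dst` computes, here is the same formula rewritten with NumPy (not the JAX version from the model) and evaluated at made-up parameters. With positive rates it is a smooth "boxcar": it sits at the baseline theta_1, rises by theta_2 near theta_3, and falls back near theta_4. This is why the model builds theta_4 as theta_3 + length:

```python
import numpy as np


def dst(theta, time):
    # Same formula as in the model, using NumPy for a standalone check.
    return theta[..., 0] + 0.5 * theta[..., 1] * (
        np.tanh(theta[..., 4] * (time - theta[..., 2]))
        - np.tanh(theta[..., 5] * (time - theta[..., 3]))
    )


# Hypothetical parameters: baseline 1, height 2, start at t=0, end at t=10,
# both transitions with rate 5.
theta = np.array([1.0, 2.0, 0.0, 10.0, 5.0, 5.0])

before = dst(theta, -100.0)  # well before the start: baseline theta_1
middle = dst(theta, 5.0)     # between start and end: theta_1 + theta_2
after = dst(theta, 100.0)    # well after the end: back to baseline
print(before, middle, after)
```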

The graphical model

Hi @pankajb64, designing algorithms that scale to large datasets is still an open research problem. I would suggest using SVI with subsampling rather than MCMC. For MCMC, you might want to try the "Example: Hamiltonian Monte Carlo with Energy Conserving Subsampling" page in the NumPyro documentation.

Thanks for sharing that, @fehiepsi. I’m trying the HMCECS algorithm for my model, but I’m running into issues.
My model has a discrete site c that is to be enumerated over. However, when running SVI, I get the error
ValueError: Continuous inference cannot handle discrete sample site 'c'.

I’ve never used SVI before, so any advice would be appreciated. Does SVI not work with enumerated discrete variables?

I replaced Trace_ELBO() in the SVI loss with TraceGraph_ELBO(), since the documentation suggests that it can handle discrete latent variables. I’m using AutoGuides for now, since I don’t know what guides are or what role they play. I noticed that the documentation says AutoGuides work for continuous variables, so maybe I need to define my own guide?

maybe I need to define my own guide?

Yes. TraceGraph_ELBO needs a custom guide. We don’t support TraceEnum_ELBO yet (issue #741), which would work with AutoGuides. Another option is to distribute the likelihood computation across multiple devices, but that is still a pending issue (#1425); I’m taking a look at it. Sorry, I don’t have better solutions.