I’m seeking advice on improving the runtime performance of the NumPyro model below.
Model description: I have a dataset of L objects. For each object, I sample a discrete variable c and eight continuous variables: s, h, and six parameters theta_i, which determine an analytical function (defined by the function `dst`). This function is fitted to the observed data points, one fit per object. Ideally, the fitting code would be in a nested plate, but since the number of observed data points may differ from object to object, I have instead flattened all the data points into a single vector V and do the fitting in a separate plate. The total number of data points is SL. The code and the graphical model are at the end of this post.
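For reference, here is a minimal sketch of how per-object series of different lengths can be flattened into the single vectors passed to the model (`t`, `V_obs`) together with `index_mapping`. It assumes the raw data arrive as Python lists of NumPy arrays, one array per object; the helper `flatten_ragged` and its argument names are hypothetical.

```python
import numpy as np

def flatten_ragged(times_per_object, values_per_object):
    """Hypothetical helper: flatten per-object series of different lengths.

    Returns (t, V_obs, index_mapping), where index_mapping[j] is the index
    of the object that data point j belongs to.
    """
    # Number of observed data points for each object.
    counts = np.array([len(v) for v in values_per_object])
    t = np.concatenate(times_per_object)
    V_obs = np.concatenate(values_per_object)
    # Repeat each object index once per data point of that object.
    index_mapping = np.repeat(np.arange(len(counts)), counts)
    return t, V_obs, index_mapping

# Example: three objects with 2, 3 and 1 data points (SL = 6).
t, V_obs, index_mapping = flatten_ragged(
    [np.array([0.0, 1.0]), np.array([0.0, 0.5, 1.0]), np.array([2.0])],
    [np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0]), np.array([6.0])],
)
print(index_mapping)  # [0 0 1 1 1 2]
```

The resulting `index_mapping` has length SL and is what lets a single flat plate look up the right object's parameters for every data point.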
When testing on a dataset with L ~ 3000 and SL ~ 21000, it takes about 5 minutes to run 4 chains with 1000 iterations each (500 warmup, 500 samples) on a machine with an 8-core CPU and 32 GB of RAM. However, on a larger dataset with L ~ 3e5 and SL ~ 4e6, it takes about 1.5 days.
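A quick back-of-the-envelope comparison of the two runs, using only the approximate figures quoted above (1.5 days taken as 2160 minutes):

```python
# Approximate figures from the two benchmark runs described above.
small = {"objects": 3_000, "points": 21_000, "minutes": 5}
large = {"objects": 300_000, "points": 4_000_000, "minutes": 1.5 * 24 * 60}

object_ratio = large["objects"] / small["objects"]  # growth in L
point_ratio = large["points"] / small["points"]     # growth in SL
time_ratio = large["minutes"] / small["minutes"]    # growth in wall-clock time

print(f"L grew ~{object_ratio:.0f}x, SL grew ~{point_ratio:.0f}x, runtime grew ~{time_ratio:.0f}x")
print(f"runtime per data point grew ~{time_ratio / point_ratio:.1f}x")
```

On these two runs, the wall-clock time per data point is roughly 2.3 times higher for the large dataset than for the small one.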
I am seeking advice on how to speed up the sampling. I suspect that the single large vector of millions of data points is causing performance problems for NumPyro. Is there any way to tell the sampler to process the data in smaller batches?
In the code, the distribution `ImproperTruncatedNormal` is just a `Normal` distribution with positive support, as suggested in this thread.
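```python
import funsor
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

class ImproperTruncatedNormal(dist.Normal):
    support = dist.constraints.positive

# Register the class with funsor so it can be used alongside enumerated sites.
funsor.distribution.make_dist(ImproperTruncatedNormal, param_names=("loc", "scale"))

def dst(theta, time):
    return theta[..., 0] + 0.5 * theta[..., 1] * (
        jnp.tanh(theta[..., 4] * (time - theta[..., 2]))
        - jnp.tanh(theta[..., 5] * (time - theta[..., 3]))
    )
```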
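Written out as an equation, with $\theta_k$ denoting `theta[..., k-1]` (matching the stacking order theta_1 … theta_6 in the model), `dst` is

$$
v(t) = \theta_1 + \tfrac{1}{2}\,\theta_2\left[\tanh\!\big(\theta_5\,(t-\theta_3)\big) - \tanh\!\big(\theta_6\,(t-\theta_4)\big)\right]
$$

Assuming $\theta_5, \theta_6 > 0$ and $\theta_3 < \theta_4$, which the positive supports of theta_5, theta_6 and length enforce, each tanh saturates at $\pm 1$ away from its centre, so the curve is a smoothed top-hat with baseline $\theta_1$ and height $\theta_2$:

$$
v(t) \approx
\begin{cases}
\theta_1, & t \ll \theta_3,\\
\theta_1 + \theta_2, & \theta_3 \ll t \ll \theta_4,\\
\theta_1, & t \gg \theta_4,
\end{cases}
$$

where the approximations hold once $t$ is several multiples of $1/\theta_5$ (respectively $1/\theta_6$) away from $\theta_3$ (respectively $\theta_4$), and the middle case additionally needs $\theta_4 - \theta_3$ to be wide compared with both edge widths.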
```python
def my_model(V_obs, t, index_mapping, L, pi, theta_mean, theta_std,
             s_line_fit_params, h_line_fit_params, s_prior, h_prior, sigma, SL):
    # Per-object latent variables; the discrete variable c is marginalised
    # out by parallel enumeration.
    with numpyro.plate("L", L):
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
        h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))
        theta_1 = numpyro.sample("theta_1", dist.Normal(loc=jnp.array(theta_mean[c, 0]), scale=jnp.array(theta_std[c, 0])))
        theta_2 = numpyro.sample("theta_2", ImproperTruncatedNormal(loc=jnp.array(theta_mean[c, 1]), scale=jnp.array(theta_std[c, 1])))
        theta_5 = numpyro.sample("theta_5", ImproperTruncatedNormal(loc=jnp.array(theta_mean[c, 2]), scale=jnp.array(theta_std[c, 4])))
        theta_6 = numpyro.sample("theta_6", ImproperTruncatedNormal(loc=jnp.array(theta_mean[c, 3]), scale=jnp.array(theta_std[c, 5])))
        # Prior locations of theta_3 (gamma_3) and of theta_4 (gamma_4, via length)
        # are per-class linear functions of s and h.
        gamma_3 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
        gamma_4 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])
        gamma_length = gamma_4 - gamma_3
        sigma_gamma_length = jnp.sqrt(theta_std[c, 3]**2 - theta_std[c, 2]**2)
        theta_3 = numpyro.sample("theta_3", ImproperTruncatedNormal(loc=gamma_3, scale=theta_std[c, 2]))
        # length = theta_4 - theta_3 has positive support, so theta_4 > theta_3.
        length = numpyro.sample("length", ImproperTruncatedNormal(loc=gamma_length, scale=sigma_gamma_length))
        theta_4 = numpyro.deterministic("theta_4", theta_3 + length)
        theta = numpyro.deterministic("theta", jnp.stack([theta_1, theta_2, theta_3, theta_4, theta_5, theta_6], axis=-1))
    # All data points are flattened into one vector of length SL;
    # index_mapping[j] is the object that data point j belongs to.
    with numpyro.plate("SL", SL):
        v_t = dst(theta[..., index_mapping, :], t)
        V = numpyro.sample("V", dist.Normal(v_t, sigma), obs=V_obs)
```
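A small shape check of the flattened likelihood term, using made-up parameter values (a sketch that assumes only JAX): indexing the (L, 6) `theta` array with `index_mapping` produces one parameter row per data point, so `dst` returns one predicted value per entry of `t`, which is the shape the `SL` plate expects.

```python
import jax.numpy as jnp

def dst(theta, time):
    # Same function as in the model: baseline plus a smoothed top-hat.
    return theta[..., 0] + 0.5 * theta[..., 1] * (
        jnp.tanh(theta[..., 4] * (time - theta[..., 2]))
        - jnp.tanh(theta[..., 5] * (time - theta[..., 3]))
    )

# Toy parameters for L = 2 objects; columns are theta_1 ... theta_6.
theta = jnp.array([
    [1.0, 2.0, 0.0, 10.0, 50.0, 50.0],
    [0.0, 4.0, 5.0, 6.0, 50.0, 50.0],
])
# SL = 4 flattened data points: three belong to object 0, one to object 1.
index_mapping = jnp.array([0, 0, 0, 1])
t = jnp.array([-5.0, 5.0, 15.0, 5.5])

params = theta[..., index_mapping, :]  # shape (SL, 6): one row per data point
v_t = dst(params, t)                   # shape (SL,)
print(params.shape, v_t.shape)  # (4, 6) (4,)
print(v_t)  # [1. 3. 1. 4.]
```

With steep edges (theta_5 = theta_6 = 50), each prediction sits at either the baseline theta_1 or at theta_1 + theta_2, depending on whether the time falls inside the object's [theta_3, theta_4] window.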
The graphical model