How to speed up Predictive

Hi!

I am using Predictive to predict Y for a given set of parameters.

predictive = numpyro.infer.Predictive(model, samples, parallel=True)
pred = predictive(rng_key, X=X, D_Y=D_Y, Y=None, D_H=D_H, prior_std=prior_std)

Here, samples is a single set of sampled parameters and X has many rows (I need to make sequential predictions, so I loop over timesteps with the same set of parameters).

Prediction is surprisingly slow, and it only uses a single core. Do I need to use vmap to split the data up to get parallelism for a single set of parameters?

Maybe it would be worth exporting the parameters to a clone of the model written in jax/pytorch for a speedup?


exporting the parameters to a model clone in jax/pytorch for speedup

Sounds reasonable to me if you are using distributions with slow samplers (like Gamma). If you are using a GPU, setting parallel=True will help. If you are using a CPU, setting parallel=False will be faster. If you want to distribute the computation across your cores, you can set batch_ndims=0 and use pmap.

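import jax
import numpyro
# pmap maps over the leading (device) axis of `samples` and `rng_keys`;
# on CPU, call numpyro.set_host_device_count(n) before any JAX code runs.
def get_pred(sample, rng_key):
    predictive = numpyro.infer.Predictive(model, sample, batch_ndims=0)
    return predictive(rng_key, ...)  # "..." = the model arguments, as in the call above

pred = jax.pmap(get_pred)(samples, rng_keys)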
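To make the layout pmap expects concrete: it maps over the leading axis of every input, and that axis cannot be longer than the number of devices, so S stacked samples are typically reshaped to (n_devices, S // n_devices, ...) and each device vmaps over its share. A minimal sketch, assuming S is divisible by the device count; `shard` and `predict_one` are hypothetical helpers (the latter a linear-Gaussian stand-in for a single-sample predictive), not numpyro API:

```python
import jax
import jax.numpy as jnp


def shard(tree, n_devices):
    """Reshape every leaf from (S, ...) to (n_devices, S // n_devices, ...).

    Assumes S is divisible by n_devices.
    """
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n_devices, -1) + x.shape[1:]), tree
    )


def predict_one(sample, rng_key, X):
    # Hypothetical single-sample predictive: linear mean plus Gaussian noise.
    mean = X @ sample["w"]
    return mean + 0.1 * jax.random.normal(rng_key, mean.shape)


n_dev = jax.local_device_count()
S, N, D = 4 * n_dev, 5, 3
samples = {"w": jnp.ones((S, D))}
keys = jax.random.split(jax.random.PRNGKey(0), S)
X = jnp.ones((N, D))

# vmap over the samples held by one device, pmap across devices; X is broadcast.
per_device = jax.vmap(predict_one, in_axes=(0, 0, None))
preds = jax.pmap(per_device, in_axes=(0, 0, None))(
    shard(samples, n_dev), shard(keys, n_dev), X
)
preds = preds.reshape((S, N))  # back to one row of predictions per sample
```

On CPU, `jax.local_device_count()` only reports more than one device if `numpyro.set_host_device_count(n)` was called before JAX initialized.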

Thanks for your suggestion.
I ended up porting the model to jax and got a significant speedup.

Glad that you solved the problem. It would be nice if you could share your model so other users know when they need to use jax directly. Maybe this is also a chance to improve Predictive.

I do not know how useful it is, but here are the models that I found could be sped up. pyromodel is the model I use for sampling with numpyro, and jaxmodel is the one used for fast evaluation of a single parameter sample. D_H is around 200 and n_hidden_layers = 3. I guess it would be possible to automatically translate a model designed for numpyro into pure jax, but I was unsure how to (maybe this would be an interesting feature for Predictive in the long run?).

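import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# nonlin (the activation function) is defined elsewhere in the original script.
def pyromodel(X, Y, D_Y, reward_dim, D_H, prior_std, n_hidden_layers, pre_proc, conditional_std=True, diff_model=True):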
    preproc_x = pre_proc(X)
    D_X = preproc_x.shape[1]
    data_size = X.shape[0]
    def distfun(x,y, invgamma = False):
        if prior_std != -1:
            if invgamma:
                return dist.InverseGamma(x,y)
            else:
                return dist.Normal(x,y)
        else:
            return dist.ImproperUniform(dist.constraints.real, (), x.shape)
    prior_var = prior_std**2
    # sample first layer (Normal(0, prior_var) priors on all weights; flat improper priors if prior_std == -1)
    w1 = numpyro.sample(
        "w0", distfun(jnp.zeros((D_X, D_H)), prior_var*jnp.ones((D_X, D_H)))
    )  # D_X D_H
    b1 = numpyro.sample(
        "b0", distfun(jnp.zeros((D_H)), prior_var*jnp.ones((D_H)))
    )  # D_H
    z = nonlin(jnp.matmul(preproc_x, w1) + b1)  # N D_H  <= first layer of activations
    for i in range(n_hidden_layers):
        # sample hidden layer i + 1
        w = numpyro.sample(
            "w{}".format(i+1), distfun(jnp.zeros((D_H, D_H)), prior_var*jnp.ones((D_H, D_H)))
        )  # D_H D_H
        b = numpyro.sample(
            "b{}".format(i+1), distfun(jnp.zeros(D_H), prior_var*jnp.ones(D_H))
        )  # D_H
        h = jnp.matmul(z, w) + b
        z = nonlin(h)  # N D_H  <= hidden activations


    # sample final layer of weights and neural network output
    w = numpyro.sample(
        "w{}".format(n_hidden_layers+1), distfun(jnp.zeros((D_H, D_Y)), prior_var*jnp.ones((D_H, D_Y)))
    )
    b = numpyro.sample(
        "b{}".format(n_hidden_layers+1), distfun(jnp.zeros((D_Y)), prior_var*jnp.ones((D_Y)))
    )  # D_Y
    zout = jnp.matmul(z, w)  + b  # N D_Y  <= output of the neural network
    if diff_model:
        if reward_dim:
            zout = zout.at[:,:-reward_dim].add(X[:, :D_Y-reward_dim])
        else:
            zout = zout.at[:,:].add(X[:, :D_Y])

    if conditional_std:
        wvar = numpyro.sample(
            "wvar", distfun(jnp.ones((D_H, D_Y)), prior_var*jnp.ones((D_H, D_Y)))
        )
        bvar = numpyro.sample(
            "bvar", distfun(jnp.ones(D_Y), prior_var * jnp.ones(D_Y))
        )
        zlogvar = jnp.matmul(z, wvar) + bvar
        zlogvar = jnp.clip(zlogvar, -10, 10)
        zvar = jnp.exp(zlogvar)
        zvar = numpyro.deterministic("zvar", zvar)

    else:
        zvar = numpyro.sample(
            "zvar", distfun(5*jnp.ones(D_Y), jnp.ones(D_Y), invgamma=True)
        )  # D_Y
    zout = numpyro.deterministic("zout", zout)
    numpyro.sample("Y", dist.LowRankMultivariateNormal(loc=zout, cov_factor=jnp.zeros((D_Y, 1)), cov_diag=zvar), obs=Y)

Jaxmodel:

import jax
import jax.numpy as jnp

def jaxmodel(params, X, D_Y, reward_dim, n_hidden_layers, pre_proc, key=None, conditional_std=True, diff_model=True,
             sample=True, model_id=0):
    preproc_x = pre_proc(X)  # optional preprocessing of the states
    data_size = X.shape[0]
    # first layer; weights come from params (posterior sample number model_id), nothing is sampled
    z = nonlin(jnp.matmul(preproc_x, params["w0"][model_id]) + params["b0"][model_id])  # N D_H  <= first layer of activations

    for i in range(n_hidden_layers):
        # hidden layer i + 1
        h = jnp.matmul(z, params["w{}".format(i + 1)][model_id]) + params["b{}".format(i + 1)][model_id]
        z = nonlin(h)  # N D_H  <= hidden activations
    zout = jnp.matmul(z, params["w{}".format(n_hidden_layers + 1)][model_id]) + params[
        "b{}".format(n_hidden_layers + 1)][model_id]  # N D_Y  <= output of the neural network
    if diff_model:
        if reward_dim > 0:
            zout = zout.at[:,:-reward_dim].add(X[:, :D_Y-reward_dim])
        else:
            zout = zout.at[:,:].add(X[:, :D_Y])

    if conditional_std:
        zlogvar = jnp.matmul(z, params["wvar"][model_id]) + params["bvar"][model_id]
        max_logvar = 5
        min_logvar = -3
        # zlogvar = max_logvar - jax.nn.softplus(max_logvar - zlogvar)
        # zlogvar = min_logvar + jax.nn.softplus(zlogvar - min_logvar)
        zlogvar = jnp.clip(zlogvar, -10, 10)
        zvar = jnp.exp(zlogvar)

    else:
        zvar = params["zvar"][model_id]
    zstd = jnp.sqrt(zvar)
    if not sample:
        return zout, zstd
    randomval = jax.random.normal(key, shape=zvar.shape)
    return randomval * zstd + zout

Thanks a lot! It is interesting that the two functions are so similar but run at different speeds. I don't think the numpyro program adds overhead here. Could you share some dummy data for the model inputs? (I guess we only need preproc_x rather than X.) I can take a look this weekend.

This post could not be more timely: I'm dealing with the same slow Predictive function on a sequential model. @emilio, do you have a hypothesis as to why converting to jax sped things up?

jaxmodel is the one used for fast evaluation of a single parameter sample

Yeah, as mentioned in one of the previous comments, you need to set batch_ndims=0 to run Predictive with a single sample. This is likely the cause of the slowness.


Hi!

@cfusting, unfortunately I don't really know what the big difference is, but I think it might come down to being able to jit-compile pure jax.
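A minimal sketch of what jit-compiling a jaxmodel-style function involves: Python-level arguments (layer counts, activation functions, flags) must be marked static, while arrays stay traced. `mlp_forward`, `init_params` and the tanh `nonlin` below are hypothetical simplifications of jaxmodel (mean path only):

```python
import jax
import jax.numpy as jnp


def nonlin(x):
    # Hypothetical activation; the thread does not say which one jaxmodel uses.
    return jnp.tanh(x)


def mlp_forward(params, X, n_hidden_layers, nonlin):
    # Trimmed-down version of jaxmodel's mean path (no diff_model, no variance head).
    z = nonlin(X @ params["w0"] + params["b0"])
    for i in range(n_hidden_layers):  # Python loop, unrolled once at trace time
        z = nonlin(z @ params[f"w{i + 1}"] + params[f"b{i + 1}"])
    k = n_hidden_layers + 1
    return z @ params[f"w{k}"] + params[f"b{k}"]


def init_params(key, D_X, D_H, D_Y, n_hidden_layers):
    # Random weights with the same layout as pyromodel's w0..w{n+1}, b0..b{n+1}.
    sizes = [D_X] + [D_H] * (n_hidden_layers + 1) + [D_Y]
    params = {}
    for i, (m, n) in enumerate(zip(sizes[:-1], sizes[1:])):
        key, sub = jax.random.split(key)
        params[f"w{i}"] = jax.random.normal(sub, (m, n)) / jnp.sqrt(m)
        params[f"b{i}"] = jnp.zeros(n)
    return params


# n_hidden_layers and nonlin are plain Python values, so they are marked static;
# changing either one triggers a recompile, new params or X of the same shapes do not.
fast_forward = jax.jit(mlp_forward, static_argnames=("n_hidden_layers", "nonlin"))

params = init_params(jax.random.PRNGKey(0), D_X=4, D_H=32, D_Y=4, n_hidden_layers=3)
X = jnp.ones((10, 4))
out = fast_forward(params, X, n_hidden_layers=3, nonlin=nonlin)  # shape (10, 4)
```

For jaxmodel itself, the analogous call would presumably be `jax.jit(jaxmodel, static_argnames=("D_Y", "reward_dim", "n_hidden_layers", "pre_proc", "conditional_std", "diff_model", "sample"))`, since those steer Python control flow or slicing; params, X, key and model_id can stay traced.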

@fehiepsi Unfortunately, I do not see a significant difference between batch_ndims=0 and batch_ndims=1.

I created an example to show the difference in speed between Predictive and pure jax. I apologize if it is a bit messy; I just cut and pasted the important parts from my own code. It is also entirely possible that the difference is due to my lack of numpyro/jax experience. The only difference I can see that is not a numpyro-vs-pure-jax issue is that in jaxmodel I manually transform univariate Gaussian noise into a multivariate diagonal Gaussian, while in the numpyro version of the model I use LowRankMultivariateNormal, but I doubt this is a major issue.

On GPU I found a 10x speedup when I jit jaxmodel, and only a minor speedup without jit. Maybe the speedup is mostly because I have not managed to jit-compile predict (if that is even possible).

Thanks, @emilio! It seems that you are comparing non-jitted code to jitted code. To jit your numpyro code, you can do

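import time
import jax

# predict, pyromodel, lastsample, state, D_Y, D_H, n_hidden_layers, pre_proc,
# key and iter come from @emilio's example code.
pyro_predict = jax.jit(lambda key, state: predict(
    pyromodel, key, lastsample, state, D_Y, 0, D_H, n_hidden_layers, -1, True, pre_proc, 0)[0])
tmp = pyro_predict(key[0], state)  # first call triggers compilation ("prejit")
t0 = time.time()
for i in range(iter):
  state = pyro_predict(key[i], state)
jax.block_until_ready(state)  # JAX dispatches asynchronously; wait before stopping the clock
print(time.time() - t0)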

which should give you similar performance.


Yes, you are totally correct (again). I just thought I was unable to jit numpyro code for some reason.
Numpyro is slightly slower, but not significantly (around 0.9 s vs 0.7 s). Can I expect a further speedup by using vmap, or does jit handle that itself?

Thank you so much for your help.

Which operation would you want to vmap? If you want to compute predictions for a batch of states, you can set batch_ndims=1 (the default value) and parallel=True. Jitting should still work then.

Thanks!