I do not know how useful this is, but here are the models that I found could be sped up. `pyromodel` is the model I use for sampling with NumPyro, and `jaxmodel` is the one used for fast evaluation of a single parameter sample. D_H is around 200 and n_hidden_layers = 3. I guess it would be possible to automatically translate a model designed for NumPyro into pure JAX, but I was unsure how to do it (maybe this would be an interesting feature for Predictive in the long run?).
def pyromodel(X, Y, D_Y, reward_dim, D_H, prior_std, n_hidden_layers, pre_proc, conditional_std=True, diff_model=True):
    """NumPyro model of a fully connected Bayesian network with Gaussian outputs.

    The network is ``pre_proc(X) -> D_H -> (D_H) * n_hidden_layers -> D_Y``,
    with the module-level ``nonlin`` applied after every layer except the
    output layer.

    Registered sites, in sampling order: ``w0``/``b0``, ``w{i}``/``b{i}`` for
    ``i = 1 .. n_hidden_layers + 1``, then ``wvar``/``bvar`` (only when
    ``conditional_std``), ``zvar``, ``zout`` and the observed site ``Y``.

    Args:
        X: Input batch. ``pre_proc(X)`` must be 2-D; its second dimension is
            the width of the first layer. When ``diff_model`` is set, the
            leading columns of ``X`` are added to the network output.
        Y: Observations passed as ``obs`` to the ``Y`` site.
        D_Y: Number of output columns.
        reward_dim: Number of trailing output columns that do NOT receive the
            ``X`` residual when ``diff_model`` is set. A falsy value means
            all ``D_Y`` columns receive it.
        D_H: Width of every hidden layer.
        prior_std: ``-1`` selects improper uniform priors; any other value
            selects ``Normal`` priors parameterised with ``prior_std ** 2``.
        n_hidden_layers: Number of ``D_H x D_H`` layers after the first one.
        pre_proc: Callable applied to ``X`` before the first layer.
        conditional_std: If true, the output log-variance is a linear function
            of the last hidden activations (clipped to ``[-10, 10]``);
            otherwise one variance per output column is sampled.
        diff_model: If true, the network output is added to columns of ``X``.
    """
    preproc_x = pre_proc(X)
    D_X = preproc_x.shape[1]
    prior_var = prior_std ** 2

    def distfun(x, y, invgamma=False):
        # prior_std == -1 is the sentinel for "no proper prior".
        if prior_std == -1:
            # A variance must stay positive, so the improper prior used in
            # place of the inverse-gamma is restricted to the positive reals.
            support = dist.constraints.positive if invgamma else dist.constraints.real
            return dist.ImproperUniform(support, (), x.shape)
        if invgamma:
            return dist.InverseGamma(x, y)
        # NOTE(review): `y` is prior_std ** 2 at the call sites below, yet the
        # second argument of dist.Normal is documented as the scale (standard
        # deviation) -- confirm whether prior_std itself was intended.
        return dist.Normal(x, y)

    def sample_affine(index, n_in, n_out):
        # Draw the weight matrix first and the bias second so the site order
        # is w0, b0, w1, b1, ...
        weight = numpyro.sample(
            "w{}".format(index),
            distfun(jnp.zeros((n_in, n_out)), prior_var * jnp.ones((n_in, n_out))),
        )
        bias = numpyro.sample(
            "b{}".format(index),
            distfun(jnp.zeros(n_out), prior_var * jnp.ones(n_out)),
        )
        return weight, bias

    # Input layer.
    w, b = sample_affine(0, D_X, D_H)
    z = nonlin(jnp.matmul(preproc_x, w) + b)

    # Hidden layers.
    for i in range(n_hidden_layers):
        w, b = sample_affine(i + 1, D_H, D_H)
        z = nonlin(jnp.matmul(z, w) + b)

    # Linear output layer.
    w, b = sample_affine(n_hidden_layers + 1, D_H, D_Y)
    zout = jnp.matmul(z, w) + b

    if diff_model:
        # The network predicts an increment on top of the leading X columns;
        # the last `reward_dim` output columns are left as raw predictions.
        if reward_dim:
            zout = zout.at[:, :-reward_dim].add(X[:, :D_Y - reward_dim])
        else:
            zout = zout.at[:, :].add(X[:, :D_Y])

    if conditional_std:
        # Input-dependent variance: a linear head on the last hidden layer.
        wvar = numpyro.sample(
            "wvar", distfun(jnp.ones((D_H, D_Y)), prior_var * jnp.ones((D_H, D_Y)))
        )
        bvar = numpyro.sample(
            "bvar", distfun(jnp.ones(D_Y), prior_var * jnp.ones(D_Y))
        )
        zlogvar = jnp.clip(jnp.matmul(z, wvar) + bvar, -10, 10)
        zvar = numpyro.deterministic("zvar", jnp.exp(zlogvar))
    else:
        # One variance per output column. `distfun` is the local helper above,
        # not an attribute of `dist`.
        zvar = numpyro.sample(
            "zvar", distfun(5 * jnp.ones(D_Y), jnp.ones(D_Y), invgamma=True)
        )

    zout = numpyro.deterministic("zout", zout)
    # The zero low-rank factor leaves a purely diagonal covariance.
    numpyro.sample(
        "Y",
        dist.LowRankMultivariateNormal(loc=zout, cov_factor=jnp.zeros((D_Y, 1)), cov_diag=zvar),
        obs=Y,
    )
Jaxmodel:
def jaxmodel(params, X, D_Y, reward_dim, n_hidden_layers, pre_proc, key=None, conditional_std=True, diff_model=True,
             sample=True, model_id=0):
    """Evaluate the network for one stored parameter set using plain JAX.

    Every entry of ``params`` is indexed with ``model_id`` before use, so each
    array is expected to carry a leading axis that enumerates parameter sets.

    Args:
        params: Mapping with keys ``w0``/``b0`` .. ``w{n_hidden_layers + 1}``/
            ``b{n_hidden_layers + 1}``, plus ``wvar``/``bvar`` when
            ``conditional_std`` is true or ``zvar`` otherwise.
        X: Input batch. When ``diff_model`` is set, its leading columns are
            added to the network output.
        D_Y: Number of output columns.
        reward_dim: Number of trailing output columns that do NOT receive the
            ``X`` residual when ``diff_model`` is set.
        n_hidden_layers: Number of hidden layers after the first one.
        pre_proc: Callable applied to ``X`` before the first layer.
        key: JAX PRNG key; required when ``sample`` is true.
        conditional_std: If true, the variance comes from the ``wvar``/``bvar``
            head; otherwise from the stored ``zvar``.
        diff_model: If true, the network output is added to columns of ``X``.
        sample: If true, return one Gaussian draw; otherwise return the mean
            and standard deviation.
        model_id: Index of the parameter set to use.

    Returns:
        ``(zout, zstd)`` when ``sample`` is false, otherwise
        ``zout + zstd * noise`` with standard-normal ``noise`` shaped like
        ``zout``.

    Raises:
        ValueError: If ``sample`` is true and ``key`` is ``None``.
    """
    # Fail before any computation instead of deep inside jax.random.
    if sample and key is None:
        raise ValueError("jaxmodel requires a PRNG `key` when sample=True")

    def pick(name):
        # Select the parameter set `model_id` for one named array.
        return params[name][model_id]

    # Input layer.
    z = nonlin(jnp.matmul(pre_proc(X), pick("w0")) + pick("b0"))

    # Hidden layers.
    for i in range(1, n_hidden_layers + 1):
        z = nonlin(jnp.matmul(z, pick("w{}".format(i))) + pick("b{}".format(i)))

    # Linear output layer.
    out_idx = n_hidden_layers + 1
    zout = jnp.matmul(z, pick("w{}".format(out_idx))) + pick("b{}".format(out_idx))

    if diff_model:
        # The network predicts an increment on top of the leading X columns;
        # the last `reward_dim` output columns are left as raw predictions.
        if reward_dim > 0:
            zout = zout.at[:, :-reward_dim].add(X[:, :D_Y - reward_dim])
        else:
            zout = zout.at[:, :].add(X[:, :D_Y])

    if conditional_std:
        # Same hard clip on the log-variance as pyromodel uses.
        zlogvar = jnp.clip(jnp.matmul(z, pick("wvar")) + pick("bvar"), -10, 10)
        zvar = jnp.exp(zlogvar)
    else:
        # Indexed by model_id like every other parameter; without the index
        # the whole stack of variances would be used.
        zvar = pick("zvar")
    zstd = jnp.sqrt(zvar)

    if not sample:
        return zout, zstd

    # Draw noise with the shape of the output so every row gets its own
    # noise even when zstd holds only one value per output column.
    noise = jax.random.normal(key, shape=zout.shape)
    return noise * zstd + zout