Hi All,
i need a bit of help adding in a numpyro.sample statement (or an equivalent stochastic statement) inside of jax.lax.scan.
The steps of the code that are giving me trouble are below.
The steps are
- Vmap across a set of parameters
- Inside the Vmap, there exists a scan
- Inside the scan there is a statement:
stochastic_states = numpyro.sample(f'stochastic_states_{t}', dist.Dirichlet(y) )
If i run the below code using jax.lax.scan a “UnexpectedTracerError” error is returned
If i run the below code replacing the jax.lax.scan with a scan from numpyro.contrib.control_flow then a shape mismatch error is returned
The code works/runs without the sample statement.
How do i include a stochastic component inside my scan?
Thanks for the help
def step(carry,arr, M,ttls, beta_a1_beta2_a2):
s,I,i,H,h,R = carry
t,beta1,ky = arr
beta2,a1,a2,alpha_removed_i,alpha_removed_h = beta_a1_beta2_a2
a1 = beta1.reshape(num_locations,1)*(a1.reshape(num_locations,LAG))
a2 = (a2.reshape(num_locations,LAG))
contributions = ( ((i*a1).sum(1)).reshape(num_locations,1) )
l = jnp.exp(-1*jnp.matmul(M,contributions) )
#--iniital proportions
init_s = s[:,0].reshape(-1,1)
init_I = I[:,0].reshape(-1,1)
init_i = i[:,0].reshape(-1,1)
init_H = H[:,0].reshape(-1,1)
init_h = h[:,0].reshape(-1,1)
init_R = R[:,0].reshape(-1,1)
#--new proportions
new_s = (init_s*l).reshape(-1,1)
new_i = (init_s*(1.-l)).reshape(-1,1)
new_I = (init_I+new_i - alpha_removed_i*init_I).reshape(-1,1)
new_h = (a2*i).sum(1).reshape(num_locations ,1)
new_H = (init_H+new_h - alpha_removed_h*init_H).reshape(-1,1)
new_R = (init_R + alpha_removed_i*new_I + alpha_removed_h*new_H).reshape(-1,1)
#--add stochastic component
y = jnp.hstack([new_s,new_I,new_i, new_H, new_h, new_R])#.T
#THIS IS THE PROBLEM
stochastic_states = numpyro.sample(f'stochastic_states_{t}', dist.Dirichlet(y) )
#---
x = jnp.delete(carry,-1,axis=2)
y = stochastic_states.T
z = jnp.zeros( (6,num_locations,LAG) )
z = z.at[..., 1: ].set( x )
z = z.at[...,0].set(y.reshape(6,num_locations))
return z,y
#----------------------------------------------------------------
#--COLLECT STATES------------------------------------------------
ttls = jnp.matmul(M,N).reshape(-1,1)
def take_step(z,inits):
s0,I0,i0,H0,h0,R0 = inits
b1,beta2,a1,a2,alpha_removed_i,alpha_removed_h = z
z=(beta2,a1,a2,alpha_removed_i,alpha_removed_h)
states = jnp.stack([s0,I0,i0,H0,h0,R0])
rints = jax.random.split( jax.random.PRNGKey(np.random.randint(0,1000)) ,time_units)
def stoch_step(x,y):
return step(x,y,M,ttls,z)
_,scanned = jax.lax.scan( stoch_step , init = states, xs = (jnp.arange(time_units),b1.T,rints))
states = jnp.moveaxis(states,[2],[0])
rslt = jnp.vstack( [ states, scanned])
return rslt[-(time_units):,...]
states=jax.vmap( lambda a,b: numpyro.handlers.seed( lambda z,inits: take_step(z,inits), jax.random.PRNGKey(1))( a,b ) )( (beta1,beta2.T,a1.T,a2.T,alpha_removed_i,alpha_removed_h), (s0,I0,i0,H0,h0,R0) )