Hi All,
i need a bit of help adding in a numpyro.sample statement (or an equivalent stochastic statement) inside of jax.lax.scan.
The steps of the code that are giving me trouble are below.
The steps are
- Vmap across a set of parameters
 - Inside the Vmap, there exists a scan
 - Inside the scan there is a statement: 
stochastic_states = numpyro.sample(f'stochastic_states_{t}', dist.Dirichlet(y) ) 
If i run the below code using jax.lax.scan a “UnexpectedTracerError” error is returned
If i run the below code replacing the jax.lax.scan with a scan from numpyro.contrib.control_flow then a shape mismatch error is returned
The code works/runs without the sample statement.
How do i include a stochastic component inside my scan?
Thanks for the help
        def step(carry,arr, M,ttls, beta_a1_beta2_a2):
            s,I,i,H,h,R = carry
            t,beta1,ky = arr
            beta2,a1,a2,alpha_removed_i,alpha_removed_h = beta_a1_beta2_a2
            a1  = beta1.reshape(num_locations,1)*(a1.reshape(num_locations,LAG))
            a2  = (a2.reshape(num_locations,LAG))
            contributions = ( ((i*a1).sum(1)).reshape(num_locations,1)  )
            l             = jnp.exp(-1*jnp.matmul(M,contributions)  )
            #--iniital proportions
            init_s   = s[:,0].reshape(-1,1)
            init_I   = I[:,0].reshape(-1,1)
            init_i   = i[:,0].reshape(-1,1)
            init_H   = H[:,0].reshape(-1,1)
            init_h   = h[:,0].reshape(-1,1)
            init_R   = R[:,0].reshape(-1,1)
            #--new proportions
            new_s     = (init_s*l).reshape(-1,1)
            new_i    = (init_s*(1.-l)).reshape(-1,1)
            new_I    = (init_I+new_i - alpha_removed_i*init_I).reshape(-1,1)
            
            new_h     = (a2*i).sum(1).reshape(num_locations ,1)
            new_H     = (init_H+new_h - alpha_removed_h*init_H).reshape(-1,1)
            new_R     = (init_R + alpha_removed_i*new_I + alpha_removed_h*new_H).reshape(-1,1)
            
            #--add stochastic component
            y = jnp.hstack([new_s,new_I,new_i, new_H, new_h, new_R])#.T
            
           #THIS IS THE PROBLEM
            stochastic_states   = numpyro.sample(f'stochastic_states_{t}', dist.Dirichlet(y) )
           #---
            x = jnp.delete(carry,-1,axis=2)
            y = stochastic_states.T
            z = jnp.zeros( (6,num_locations,LAG) )
            z = z.at[..., 1: ].set( x )
            z = z.at[...,0].set(y.reshape(6,num_locations))
            return z,y
        #----------------------------------------------------------------
        #--COLLECT STATES------------------------------------------------
        ttls = jnp.matmul(M,N).reshape(-1,1)
        def take_step(z,inits):
            s0,I0,i0,H0,h0,R0 = inits
            b1,beta2,a1,a2,alpha_removed_i,alpha_removed_h = z
            z=(beta2,a1,a2,alpha_removed_i,alpha_removed_h)
            states     = jnp.stack([s0,I0,i0,H0,h0,R0])
            rints      = jax.random.split( jax.random.PRNGKey(np.random.randint(0,1000)) ,time_units)
            def stoch_step(x,y):
                return step(x,y,M,ttls,z)
            _,scanned  = jax.lax.scan( stoch_step , init = states, xs = (jnp.arange(time_units),b1.T,rints))
            states = jnp.moveaxis(states,[2],[0])
            rslt = jnp.vstack( [ states, scanned])
            return rslt[-(time_units):,...]
        states=jax.vmap( lambda a,b: numpyro.handlers.seed( lambda z,inits: take_step(z,inits), jax.random.PRNGKey(1))( a,b )    )( (beta1,beta2.T,a1.T,a2.T,alpha_removed_i,alpha_removed_h), (s0,I0,i0,H0,h0,R0) )