Vmap, Scan, and numpyro.sample

Hi All,
i need a bit of help adding in a numpyro.sample statement (or an equivalent stochastic statement) inside of jax.lax.scan.

The steps of the code that are giving me trouble are below.
The steps are

  1. Vmap across a set of parameters
  2. Inside the Vmap, there exists a scan
  3. Inside the scan there is a statement: stochastic_states = numpyro.sample(f'stochastic_states_{t}', dist.Dirichlet(y) )

If i run the below code using jax.lax.scan a “UnexpectedTracerError” error is returned
If i run the below code replacing the jax.lax.scan with a scan from numpyro.contrib.control_flow then a shape mismatch error is returned

The code works/runs without the sample statement.

How do i include a stochastic component inside my scan?

Thanks for the help

        def step(carry,arr, M,ttls, beta_a1_beta2_a2):
            s,I,i,H,h,R = carry
            t,beta1,ky = arr

            beta2,a1,a2,alpha_removed_i,alpha_removed_h = beta_a1_beta2_a2

            a1  = beta1.reshape(num_locations,1)*(a1.reshape(num_locations,LAG))
            a2  = (a2.reshape(num_locations,LAG))

            contributions = ( ((i*a1).sum(1)).reshape(num_locations,1)  )
            l             = jnp.exp(-1*jnp.matmul(M,contributions)  )

            #--iniital proportions
            init_s   = s[:,0].reshape(-1,1)
            init_I   = I[:,0].reshape(-1,1)
            init_i   = i[:,0].reshape(-1,1)
            init_H   = H[:,0].reshape(-1,1)
            init_h   = h[:,0].reshape(-1,1)
            init_R   = R[:,0].reshape(-1,1)

            #--new proportions
            new_s     = (init_s*l).reshape(-1,1)

            new_i    = (init_s*(1.-l)).reshape(-1,1)
            new_I    = (init_I+new_i - alpha_removed_i*init_I).reshape(-1,1)
            
            new_h     = (a2*i).sum(1).reshape(num_locations ,1)
            new_H     = (init_H+new_h - alpha_removed_h*init_H).reshape(-1,1)

            new_R     = (init_R + alpha_removed_i*new_I + alpha_removed_h*new_H).reshape(-1,1)
            
            #--add stochastic component
            y = jnp.hstack([new_s,new_I,new_i, new_H, new_h, new_R])#.T
            
           #THIS IS THE PROBLEM
            stochastic_states   = numpyro.sample(f'stochastic_states_{t}', dist.Dirichlet(y) )
           #---
            x = jnp.delete(carry,-1,axis=2)
            y = stochastic_states.T

            z = jnp.zeros( (6,num_locations,LAG) )
            z = z.at[..., 1: ].set( x )
            z = z.at[...,0].set(y.reshape(6,num_locations))

            return z,y
        #----------------------------------------------------------------
        #--COLLECT STATES------------------------------------------------
        ttls = jnp.matmul(M,N).reshape(-1,1)
        def take_step(z,inits):
            s0,I0,i0,H0,h0,R0 = inits

            b1,beta2,a1,a2,alpha_removed_i,alpha_removed_h = z
            z=(beta2,a1,a2,alpha_removed_i,alpha_removed_h)

            states     = jnp.stack([s0,I0,i0,H0,h0,R0])
            rints      = jax.random.split( jax.random.PRNGKey(np.random.randint(0,1000)) ,time_units)
            def stoch_step(x,y):
                return step(x,y,M,ttls,z)
            _,scanned  = jax.lax.scan( stoch_step , init = states, xs = (jnp.arange(time_units),b1.T,rints))

            states = jnp.moveaxis(states,[2],[0])
            rslt = jnp.vstack( [ states, scanned])
            return rslt[-(time_units):,...]
        states=jax.vmap( lambda a,b: numpyro.handlers.seed( lambda z,inits: take_step(z,inits), jax.random.PRNGKey(1))( a,b )    )( (beta1,beta2.T,a1.T,a2.T,alpha_removed_i,alpha_removed_h), (s0,I0,i0,H0,h0,R0) )

Hi @tommcandrew, could you make a small repro code?

Absolutely. i’ll make one today and give a fuller report

Hi @fehiepsi

i have built a minimal example but it looks like i bumped into another error.
The bug is in the line stochastic_states = numpyro.sample(f'stochastic_state_{t}', dist.Dirichlet(last_state) )

It looks like t is recognized as a tracer value and not a typical float so that i can rename this site iteratively.

If that line is commented out then the code runs.

Thanks again for the help,
tom

import numpy as np
import jax
from jax import jit
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
import numpyro.handlers


if __name__ == "__main__":

    def example():
        def step(carry,arr):
            state = carry
            t = arr

            last_state = state[:,0]
            stochastic_states   = numpyro.sample(f'stochastic_state_{t}', dist.Dirichlet(last_state) )

            x = jnp.delete(carry,-1,axis=1)
            y = last_state #stochastic_states.T.reshape(5,1)

            z = jnp.zeros( (5,10) )
            z = z.at[:,1:].set( x )
            z = z.at[:,0].set(y.reshape(5,))
            
            return z,y
        #----------------------------------------------------------------
        #--COLLECT STATES------------------------------------------------
        z = jnp.zeros( (5,10) )

        #option one with error
        states = jax.lax.scan( step, init=z, xs=(jnp.arange(45)) )

        #option two with error
        states = scan( step, init=z, xs=(jnp.arange(45)) )
        

    from numpyro.infer import MCMC, NUTS,init_to_median
    mcmc = MCMC(
        NUTS(example
             , dense_mass=False, max_tree_depth=5,init_strategy = init_to_median(num_samples=100) ) #)init_to_sample() )
        , num_warmup  = 100
        , num_samples = 200
        , num_chains  = 1
        , thinning    = 2
    )
    mcmc.run(jax.random.PRNGKey(20200320))

Yes, t is a tracer. I think you don’t need to specify t in the name. See the example in scan doc.

@fehiepsi i thought the same, but when i remove t then a “repeated site name” error is returned.

Update @fehiepsi

i have an updated example.
The code in example1 runs without error.
The code in example2 returns an error with an unexpected tracer.
The code in example3 returns no error
The code in example4 returns an error: AssertionError: all sites must have unique names but got stochastic_state duplicated

Example1 i call the scan without using vmap
Example2 uses vmap
Example3 places the scan inside of another function. This example was to exclude the possibility that placing scan inside of another function could cause an issue.
Example4 tries to replace the vmap with a simpler for loop.

Is there something special about vmap (or for loop in example4) that i need to take care of for this to work?

thanks for the help,
tom




#mcandrew
import numpy as np
import jax
from jax import jit
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
import numpyro.handlers


if __name__ == "__main__":

    def example1():
        def step(carry,arr):
            state = carry
            t = arr

            last_state = state[:,0]
            stochastic_states   = numpyro.sample("stochastic",dist.Normal(jnp.ones(5,),1)) #numpyro.sample('stochastic_state', dist.Dirichlet(last_state) )

            x = jnp.delete(carry,-1,axis=1)
            y = stochastic_states.T.reshape(5,1)

            z = jnp.zeros( (5,10) )
            z = z.at[:,1:].set( x )
            z = z.at[:,0].set(y.reshape(5,))
            
            return z,y
        #----------------------------------------------------------------
        #--COLLECT STATES------------------------------------------------
        z = jnp.ones( (5,10) )

        #option two with error
        _,states = scan( step, init=z, xs=(jnp.arange(45)) )
        print(states.shape)
        
    def example2():
        def step(carry,arr):
            state = carry
            t = arr

            last_state = state[:,0]
            stochastic_states   = numpyro.sample('stochastic_state', dist.Dirichlet(last_state) )

            x = jnp.delete(carry,-1,axis=1)
            y = stochastic_states.T.reshape(5,1)

            z = jnp.zeros( (5,10) )
            z = z.at[:,1:].set( x )
            z = z.at[:,0].set(y.reshape(5,))
            
            return z,(y,x)
        #----------------------------------------------------------------
        #--COLLECT STATES------------------------------------------------
        z = jnp.ones( (5,10) )

        #option two with error
        def to_map(z):
            _,states = scan( step, init=z, xs=(jnp.arange(45)) )
            return states
        states = jax.vmap( to_map )( jnp.repeat( z[jnp.newaxis,...],3,axis=0 ) )

    def example3():
        def step(carry,arr):
            state = carry
            t = arr

            last_state = state[:,0]
            stochastic_states   = numpyro.sample('stochastic_state', dist.Dirichlet(last_state) )

            x = jnp.delete(carry,-1,axis=1)
            y = stochastic_states.T.reshape(5,1)

            z = jnp.zeros( (5,10) )
            z = z.at[:,1:].set( x )
            z = z.at[:,0].set(y.reshape(5,))
            
            return z,(y,x)
        #----------------------------------------------------------------
        #--COLLECT STATES------------------------------------------------
        z = jnp.ones( (5,10) )

        #option two with error
        def to_map(z):
            _,states = scan( step, init=z, xs=(jnp.arange(45)) )
            return states
        states =  to_map(z)

    def example4():
        def step(carry,arr,t):
            state = carry
            t = arr

            last_state = state[:,0]
            stochastic_states   = numpyro.sample('stochastic_state', dist.Dirichlet(last_state) )

            x = jnp.delete(carry,-1,axis=1)
            y = stochastic_states.T.reshape(5,1)

            z = jnp.zeros( (5,10) )
            z = z.at[:,1:].set( x )
            z = z.at[:,0].set(y.reshape(5,))
            
            return z,(y,x)
        #----------------------------------------------------------------
        #--COLLECT STATES------------------------------------------------
        z = jnp.ones( (5,10) )

        #option two with error
        def to_map(z,t):
            _,states = scan( lambda x,y :step(x,y,t), init=z, xs=(jnp.arange(45)) )
            return states
        
        for t,z in enumerate(jnp.repeat( z[jnp.newaxis,...],3,axis=0 )):
            states =  to_map(z, t)
            
    from numpyro.infer import MCMC, NUTS,init_to_median
    mcmc = MCMC(
        NUTS(example4
             , dense_mass=False, max_tree_depth=5,init_strategy = init_to_median(num_samples=100) ) #)init_to_sample() )
        , num_warmup  = 100
        , num_samples = 200
        , num_chains  = 1
        , thinning    = 2
    )
    mcmc.run(jax.random.PRNGKey(20200320))


re for loop: we need to specify different latent names in its body

re vmap: we don’t support vmap over non-closed numpyro programs. You might consider using plate in the scan body function here.