–Data generation/simulation
Suppose that we observe the number of incident cases of an infectious agent over 33 timepoints (a 1X33 vector). In addition, we observe this infectious agent over 9 seasons. Then we have a dataset Y of dimensions 9X33 where each row corresponds to a season and each column corresponds to the number of incident cases observed at that timepoint.
We hypothesize that the above data set depends on parameters: sigma, gamma, N, and lamba. All parameters are fixed except lamba. Let lambs = [ 2, 1./2 ]. Then for each season, lamb is assigned the value 2 (ie lamb[1]) with probability 0.6 and the value 1./2 (ie lamb[2]) with probability 0.4.
Incident cases can be computed from this parameter set theta = (lamb,sigma,gamma,N) and then the observed incident cases are subject to noise.
Below is the code to generate this synthetic data.
import numpy as np
import jax.numpy as jnp
import scipy.integrate
#--SIMULATION OF DATA MATRIX Y
#--Fixed population size (N)
N=1000
#--33 weeks of observation
timepoints = 33
#--suppose we observe 5 seasons
seasons = 9
#-model parameters
#--parameters that control dynamics
sigma = 1./5
gamma = 1./10
lambs = [2, 1./2]
#--parameters that control surveillance
catchment = 1./20
#--initial conditions
e0 = 0.01
i0 = 0.01
r0 = 0.00
c0 = i0
s0 = 1. - (e0+i0+r0+c0)
init = np.array([s0,e0,i0,r0,c0])
#--probability of choosing lambs[1] versus lambs[2]
probs = np.array([0.6,0.4])
assignments = []
for n,season in enumerate(range(seasons)):
#--assign lamb[1] or lamb[2] to season "season"
assign = np.random.choice([0,1],p=probs)
lamb = lambs[assign]
#--form vector of parameters
theta = np.array([lamb, sigma, gamma, N])
#--integrate model
#--ode specification
def seir(y,t,theta):
s,e,i,r, c = y
lamb,sigma,gamma,N = theta
ds = (1/N)*(-1.*lamb*s*(N*i))
de = (1/N)*(lamb*s*(N*i) - sigma*(N*e))
di = (1/N)*(sigma*(N*e) - gamma*(N*i))
dr = (1/N)*(gamma*(N*i))
dc = (1/N)*(sigma*(N*e))
return jnp.stack([ds,de,di,dr, dc])
states = scipy.integrate.odeint( seir, y0 = init, t = np.arange(timepoints), args = (theta,) )
#--compute incident cases from cumulative incident cases
incident_cases = np.append( 0, np.diff(states[:,-1]))
#--add noise
noisy_cases = np.random.poisson(incident_cases*N)
noisy_cases = noisy_cases.reshape(1,-1)
#--assignments
assignments.append(assign)
#--append to matrix
if n==0:
Y = noisy_cases
else:
Y = np.vstack([Y,noisy_cases])
i would like to fit the assumed model to the above data.
The code to fit my model in numpyro is below
import jax
from jax.experimental.ode import odeint
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC,NUTS
def model(Y,N,K=2,forecast=0):
seasons,timepoints = Y.shape
#--model diffeq
def seir(y,t,theta):
s,e,i,r, c = y
lamb,sigma,gamma,N = theta
ds = (1/N)*(-1.*lamb*s*(N*i))
de = (1/N)*(lamb*s*(N*i) - sigma*(N*e))
di = (1/N)*(sigma*(N*e) - gamma*(N*i))
dr = (1/N)*(gamma*(N*i))
dc = (1/N)*(sigma*(N*e))
return jnp.stack([ds,de,di,dr, dc])
sigma = numpyro.sample("sigma" , dist.Beta(.5,.5))
gamma = numpyro.sample("gamma" , dist.Uniform(1./10,1./2))
catchment = numpyro.sample("catchment", dist.Beta(1,100))
mix_weights = numpyro.sample("weight" , dist.Dirichlet(jnp.array([0.5,0.5])) )
with numpyro.plate("components",K,dim=-1):
mix_centers = numpyro.sample("mix_centers", dist.Gamma(2,1) )
#--loglikelihood
sigma2 = numpyro.sample("sigma2", dist.HalfCauchy(1.) )
#--initial conditions
e0 = numpyro.sample("e0", dist.Beta(1,100))
i0 = numpyro.sample("i0", dist.Beta(1,100))
r0 = numpyro.deterministic("r0",0.*e0)
s0 = numpyro.deterministic("s0",(1.-(e0+i0+r0)))
c0 = numpyro.deterministic("c0",i0)
init = jnp.array([s0,e0,i0,r0, c0])
mask = ~jnp.isnan(Y) #--incase there are missing values
with numpyro.plate("seasons", seasons, dim=-1):
assignment = numpyro.sample("assignment", dist.Categorical(mix_weights))
lambs = mix_centers[assignment]
thetas = jnp.hstack([lambs.reshape(-1,1)
, jnp.repeat(sigma,seasons).reshape(-1,1)
, jnp.repeat(gamma,seasons).reshape(-1,1)
, jnp.repeat(jnp.array([N]),seasons).reshape(-1,1)] )
states = jax.vmap( lambda theta: odeint( seir
, init
, jnp.arange(0.,timepoints)
, theta
, rtol=1e-6
, atol=1e-5
, mxstep=1000) )(thetas)
inc_cases = jnp.clip( jnp.append( jnp.repeat(i0,seasons).reshape(-1,1), jnp.diff(states[:,:,-1]),1 ),0,jnp.inf)
inc_cases = numpyro.deterministic("inc_cases",N*inc_cases)
inc_cases = inc_cases
obs_cases = inc_cases*catchment
with numpyro.handlers.mask(mask=mask):
numpyro.sample("ll", dist.Normal( obs_cases, sigma2), obs = Y )
mcmc = MCMC(
NUTS(model, dense_mass=False)
, num_warmup = 8000
, num_samples = 1000
, num_chains = 1
, thinning = 5
)
with numpyro.handlers.seed(rng_seed=1):
trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
print(numpyro.util.format_shapes(trace))
mcmc.run(PRNGKey(20200320) #--seed
, Y = Y
, N = N
, forecast = 4
)
mcmc.print_summary()
samples = mcmc.get_samples()
The error message that i cannot solve is
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:148, in broadcast_shapes(*shapes)
147 try:
--> 148 return _broadcast_shapes_cached(*shapes)
149 except:
File /usr/local/lib/python3.10/site-packages/jax/_src/util.py:263, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
262 else:
--> 263 return cached(config._trace_context(), *args, **kwargs)
File /usr/local/lib/python3.10/site-packages/jax/_src/util.py:256, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
254 @functools.lru_cache(max_size)
255 def cached(_, *args, **kwargs):
--> 256 return f(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:154, in _broadcast_shapes_cached(*shapes)
152 @cache()
153 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 154 return _broadcast_shapes_uncached(*shapes)
File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:170, in _broadcast_shapes_uncached(*shapes)
169 if result_shape is None:
--> 170 raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
171 return result_shape
ValueError: Incompatible shapes for broadcasting: shapes=[(9,), (9, 33)]
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[9], line 2
1 with numpyro.handlers.seed(rng_seed=1):
----> 2 trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
3 print(numpyro.util.format_shapes(trace))
5 mcmc.run(PRNGKey(20200320) #--seed
6 , Y = Y
7 , N = N
8 , forecast = 4
9 )
File /usr/local/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[8], line 70, in model(Y, N, K, forecast)
67 obs_cases = inc_cases*catchment
69 with numpyro.handlers.mask(mask=mask):
---> 70 numpyro.sample("ll", dist.Normal( obs_cases, sigma2), obs = Y )
File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
207 initial_msg = {
208 "type": "sample",
209 "name": name,
(...)
218 "infer": {} if infer is None else infer,
219 }
221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
223 return msg["value"]
File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
45 pointer = 0
46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47 handler.process_message(msg)
48 # When a Messenger sets the "stop" field of a message,
49 # it prevents any Messengers above it on the stack from being applied.
50 if msg.get("stop"):
File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:546, in plate.process_message(self, msg)
544 overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
545 trailing_shape = expected_shape[overlap_idx:]
--> 546 broadcast_shape = lax.broadcast_shapes(
547 trailing_shape, tuple(dist_batch_shape)
548 )
549 batch_shape = expected_shape[:overlap_idx] + broadcast_shape
550 msg["fn"] = msg["fn"].expand(batch_shape)
File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:150, in broadcast_shapes(*shapes)
148 return _broadcast_shapes_cached(*shapes)
149 except:
--> 150 return _broadcast_shapes_uncached(*shapes)
File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:170, in _broadcast_shapes_uncached(*shapes)
168 result_shape = _try_broadcast_shapes(shape_list)
169 if result_shape is None:
--> 170 raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
171 return result_shape
ValueError: Incompatible shapes for broadcasting: shapes=[(9,), (9, 33)]