Hello,
I have a two dimension time series ([time_step, observed_data]) where the first dimension is 100 and the second one is 500. I’d like to update the parameters of my model after exhausting all the time_steps in the first dimension (i.e. mcmc.run on [1, :] → [2, :] → … → [100, :] → update the parameters). I have written the code below (a simplified version is provided here) in which I am using “scan” to iterate over the first dimension of my time series. The problem is that as a result of iteration of “scan” parameter names are no longer unique and I get “AssertionError: all sites must have unique names but got nH_shift
duplicated” error. Is there any way to fix this or should I change my approach completely?
Thank you.
from numpyro.contrib.control_flow import scan
def flux(a, b, c = 10):
some operations
return flu
def model_slice(carry, slice_y):
y_prev_nh, y_prev_phi, nh_mean_shift, phi_mean_shift = carry
nh_shift = numpyro.sample("nH_shift", dist.Normal(nh_mean_shift, 1.0))
phi_shift = numpyro.sample("phi_shift", dist.Normal(phi_mean_shift, 1.0))
nh_alpha = numpyro.sample("nH_shift", dist.Normal(10, 1.0))
phi_alpha = numpyro.sample("phi_shift", dist.Normal(10, 1.0))
y_t_nh = numpyro.sample("y_t_nh", dist.Normal(1, 1.0))
y_t_phi = numpyro.sample("y_t_phi", dist.Normal(10, 2.0))
m_t_nh = nh_alpha * y_prev_nh + nH_shift
m_t_phi = phi_alpha * y_prev_phi - phi_shift
y_t_nh = numpyro.sample("y_t_nh", dist.Normal(m_t_nh, 0.1))
y_t_phi = numpyro.sample("y_t_phi", dist.Normal(m_t_phi, 0.1))
lambda = flux(y_t_nh, y_t_phi)
lambda_mean = numpyro.deterministic("lambda_mean", lambda)
with numpyro.plate("data", len(slice_y)):
numpyro.sample("y", dist.Poisson(lambda_mean), obs = slice_y)
return y_t_nh, y_t_phi, nH_shift, phi_shift
def run_inference_on_slice(carry, slice_y):
start = time.time()
rng_key = jax.random.PRNGKey(0)
sampler = numpyro.infer.NUTS(model_slice)
mcmc = numpyro.infer.MCMC(
sampler,
num_warmup = 1000,
num_samples = 50000,
num_chains = 1,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(rng_key, carry = carry, slice_y=slice_y)
samples = mcmc.get_samples()
nh_shift_new = jnp.mean(samples['nh_shift'])
phi_shift_new = jnp.mean(samples['phi_shift'])
y_t_nh_new = jnp.mean(samples['y_t_nh'])
y_t_phi_new = jnp.mean(samples['y_t_phi'])
mcmc.print_summary()
return (y_t_nh_new, y_t_phi_new, nh_shift_new, phi_shift_new), samples
def scan_fn(carry, slice_y):
carry, samples = run_inference_on_slice(carry, slice_y)
return carry, samples
initial_carry = (1.0, 1.0, 1.0, 1.0)
scan(scan_fn, initial_carry, count) ### Where count is the observed data with shape [100, 500]