Shared parameter update in scan

Hello,
I have a two-dimensional time series ([time_step, observed_data]) where the first dimension has size 100 and the second has size 500. I’d like to update the parameters of my model after exhausting all the time steps in the first dimension (i.e. mcmc.run on [1, :] → [2, :] → … → [100, :] → update the parameters). Below is a simplified version of my code, in which I use “scan” to iterate over the first dimension of the time series. The problem is that, as “scan” iterates, the parameter names are no longer unique, and I get the error “AssertionError: all sites must have unique names but got nH_shift duplicated”. Is there any way to fix this, or should I change my approach completely?

Thank you.

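import os
import time

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def flux(a, b, c=10):
    # some operations (omitted in this simplified version) that compute `flu`
    flu = ...  # elided
    return flu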

def model_slice(carry, slice_y):
    y_prev_nh, y_prev_phi, nh_mean_shift, phi_mean_shift = carry

    nh_shift  = numpyro.sample("nH_shift", dist.Normal(nh_mean_shift, 1.0))
    phi_shift = numpyro.sample("phi_shift", dist.Normal(phi_mean_shift, 1.0))
    # every sample site needs its own name, so the alpha sites get distinct names
    nh_alpha  = numpyro.sample("nh_alpha", dist.Normal(10, 1.0))
    phi_alpha = numpyro.sample("phi_alpha", dist.Normal(10, 1.0))

    m_t_nh    = nh_alpha * y_prev_nh + nh_shift
    m_t_phi   = phi_alpha * y_prev_phi - phi_shift

    # y_t_nh / y_t_phi are sampled once; a second "y_t_nh" site would also be a duplicate
    y_t_nh    = numpyro.sample("y_t_nh", dist.Normal(m_t_nh, 0.1))
    y_t_phi   = numpyro.sample("y_t_phi", dist.Normal(m_t_phi, 0.1))
    
    lam         = flux(y_t_nh, y_t_phi)  # `lambda` is a reserved word in Python
    lambda_mean = numpyro.deterministic("lambda_mean", lam)
    with numpyro.plate("data", len(slice_y)):
        numpyro.sample("y", dist.Poisson(lambda_mean), obs = slice_y)

    return y_t_nh, y_t_phi, nh_shift, phi_shift


def run_inference_on_slice(carry, slice_y):
    start   = time.time()
    rng_key = jax.random.PRNGKey(0)
    sampler = numpyro.infer.NUTS(model_slice)
    mcmc    = numpyro.infer.MCMC(
        sampler,
        num_warmup  = 1000,
        num_samples = 50000,
        num_chains  = 1, 
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(rng_key, carry = carry, slice_y=slice_y)
    samples          = mcmc.get_samples()
    nh_shift_new  = jnp.mean(samples['nH_shift'])  # key must match the site name
    phi_shift_new = jnp.mean(samples['phi_shift'])
    y_t_nh_new    = jnp.mean(samples['y_t_nh'])
    y_t_phi_new   = jnp.mean(samples['y_t_phi'])
    
    mcmc.print_summary()
    return (y_t_nh_new, y_t_phi_new, nh_shift_new, phi_shift_new), samples

def scan_fn(carry, slice_y):
    carry, samples = run_inference_on_slice(carry, slice_y)
    return carry, samples

initial_carry = (1.0, 1.0, 1.0, 1.0)
scan(scan_fn, initial_carry, count)  # count is the observed data, with shape [100, 500]

    

Your approach makes sense. Probably we didn’t block things properly, so some messages are left on the Pyro effect-handler stack. Could you try

with numpyro.handlers.block():
    carry, samples = run_inference_on_slice(carry, slice_y)

I appreciate the response. I tried the code above and it seems to suppress that error, but now I am getting another one:

TypeError: unsupported format string passed to DynamicJaxprTracer.__format__
-------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

I had already run into this error earlier when I was using “jax.lax.scan” instead of “scan” (imported from “numpyro.contrib.control_flow”). My understanding was that this typically happens because jax.lax.scan traces the computation, and some of the values being traced were DynamicJaxprTracer objects, which cannot be formatted directly into strings. Is my understanding correct? To resolve the issue I applied

slice_y_concrete = jax.lax.stop_gradient(slice_y)

inside the “run_inference_on_slice” function, right before passing it to mcmc.run, but that did not fix the problem. What can I do to fix this issue?
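A minimal sketch of the behaviour described above, assuming a recent JAX version (the exact exception subclass may vary between versions, so the sketch only checks for TypeError). Inside jax.lax.scan the body receives abstract tracers, so a numeric format spec such as `:.2f` fails while tracing; jax.debug.print is the usual way to print values from traced code because it formats them only once real values exist:

```python
import jax
import jax.numpy as jnp

errors = []


def body(carry, x):
    # During tracing, `x` is an abstract tracer with no concrete value.
    try:
        f"{x:.2f}"  # a numeric format spec needs a concrete number
    except TypeError as err:
        errors.append(err)
    # jax.debug.print defers formatting until actual values are available.
    jax.debug.print("x = {x}", x=x)
    return carry, None


jax.lax.scan(body, 0.0, jnp.arange(3.0))
print(len(errors), "format error(s) caught while tracing")
```

stop_gradient does not change this: its output inside the scan body is still a tracer.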

Your interpretation is correct. Maybe try commenting out

mcmc.print_summary()

Commenting out “mcmc.print_summary()” does not work; I still get the same error! I tried printing the types of “carry” and “slice_y” from inside “model_slice” (see below). I did this both with and without “slice_y_concrete = jax.lax.stop_gradient(slice_y)” inside the “run_inference_on_slice” function, and either way I get the error message below. Does this error come from “run_inference_on_slice” or from “model_slice”? I ask because “run_inference_on_slice” is pretty standard.

carry types  ==============> : <class 'tuple'>
slice_y type ==============>: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
warmup:   0%|          | 0/60 [00:00<?, ?it/s]
Traceback (most recent call last): 
... the rest of the error ...

Maybe you need progress_bar=False. I’m not sure why you want to stop the gradient. Also, you may want to use slice_y.shape[0] instead of len(slice_y).
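A small sketch of why slice_y.shape[0] is safe inside traced code (the data are hypothetical zeros with the question’s [100, 500] shape): shapes are static in JAX, so even while the scan body is being traced, .shape[0] is an ordinary Python int that can size a plate.

```python
import jax
import jax.numpy as jnp

seen = []


def body(carry, slice_y):
    # slice_y is a tracer here, but its shape is known at trace time.
    seen.append(slice_y.shape[0])
    return carry, None


count = jnp.zeros((100, 500))  # hypothetical stand-in for the observed data
jax.lax.scan(body, 0.0, count)
print(seen)
```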

I stripped the model down to just a few parameters and removed the “flux” function along with all the print and format calls. Now I have the code below, but I still get the same “TypeError: unsupported format string passed to DynamicJaxprTracer.__format__”!

from numpyro.infer import MCMC, NUTS

def model_slice(carry, slice_y):
    y_prev_nh, _ = carry

    gamma_value = numpyro.sample("gamma_value", dist.Normal(1, 1.0))
    nh_tau = numpyro.sample("nh_tau", dist.Normal(10, 2.0))

    nh_alpha = jnp.exp(-1 / nh_tau)
    m_t_nh = nh_alpha * y_prev_nh
    y_t_nh = numpyro.sample("y_t_nh", dist.Normal(m_t_nh, 1.0))

    # Simplified constant model
    lambda_i = 10.0
    lambda_mean = numpyro.deterministic("lambda_mean", lambda_i)
    slice_y = slice_y.astype("int32")
    with numpyro.plate("data", slice_y.shape[0]):
        numpyro.sample("y", dist.Poisson(lambda_mean), obs=slice_y)

    return (y_t_nh, _), None

# Define the inference function
def run_inference_on_slice(carry, slice_y):
    rng_key = jax.random.PRNGKey(0)
    sampler = NUTS(model_slice)
    mcmc = MCMC(
        sampler,
        num_warmup=10,
        num_samples=50,
        num_chains=1,
        progress_bar=True
    )
    slice_y_concrete = jax.device_get(slice_y)
    carry_concrete = jax.device_get(carry)
    mcmc.run(rng_key, carry=carry_concrete, slice_y=slice_y_concrete)
    samples = mcmc.get_samples()
    y_t_nh_new = jnp.mean(samples['y_t_nh'])
    # `_` is not defined in this function, so pass the second carry element through
    return (y_t_nh_new, carry[1]), samples

# Define the scan function
def scan_fn(carry, slice_y):
    with numpyro.handlers.block():
        carry, samples = run_inference_on_slice(carry, slice_y)
    return carry, samples

# Initial carry
initial_carry = (1.0, 1.0)

final_carry_params, all_samples = scan(scan_fn, initial_carry, count)

Is it possible that the line below is causing the crash, even though removing it also makes the code crash with the non-unique-name error?

with numpyro.handlers.block():

Could you try progress_bar=False? I don’t understand why you need those device_get calls etc. My line of thought is to track down which tracers are being formatted: e.g. are you formatting values to print in the progress bar, or printing some summarized values? I don’t think device_get, device_put, or stop_gradient would help with formatting issues.

Good news! Setting progress_bar=False resolved the issue. (I also commented out all the “device”- and “gradient”-related calls.) However, since all the print calls are commented out and the progress bar is disabled, nothing appears on the screen to show how fast or slow the sampler converges (if at all), which is very inconvenient! Can we somehow make the progress bar appear?

I would use a Python loop over the slices and call run_inference_on_slice directly, not through scan. You can define your mcmc object once (probably with jit_model_args=True) and run it multiple times with different input data.

Thank you for your help 🙂