Reusing sample sites of a callable

Hello everyone

I would like to know if it is possible to achieve something like the toy example below, that would reuse a sample site:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDelta


def scale(x):
    
    factor = numpyro.sample("factor", dist.HalfNormal(1))
    return x * factor

def model(x, obs):
    scaled1 = scale(x)
    
    # Can I call it again
    # in a way that the sample sites return the same values as the first call?
    scaled2 = scale(x)
    
    numpyro.sample("obs1", dist.Normal(scaled1, 0.1), obs=obs)

# Use SVI to infer

svi = SVI(model, guide=AutoDelta(model), optim=numpyro.optim.Adam(1e-3),
          loss=numpyro.infer.Trace_ELBO())

svi.run(
    rng_key=jax.random.PRNGKey(0),
    progress_bar=True,
    num_steps=100,
    x=jnp.arange(100),
    obs=jnp.arange(100)*2

)

With this code, I get “AssertionError: all sites must have unique names but got factor duplicated”, as expected.

Can the reuse of the sites in the scale function be achieved with an existing handler? Or would it be possible to implement such a handler?

Thanks!

I guess you want something like Automatic Name Generation (Pyro documentation). Please make a feature request for it.

Thank you for such a fast answer!

As far as I understand, automatic name generation creates a new site whose value is independent of the previous one. Is that right?

In my case, I would need to obtain the same value of numpyro.sample("factor"), i.e., scaled1 would be equal to scaled2 in this toy example.

I thought the replay handler could do this, but in my experiments it didn't work.

I guess you can do something along these lines:

    with numpyro.handlers.trace() as tr:
        scaled1 = scale(x)

    with numpyro.handlers.block(), numpyro.handlers.replay(trace=tr):
        scaled2 = scale(x)

It works!
But I may have found a bug. Could you kindly confirm whether this error is unexpected?
If I wrap the code in a scope handler, I get an assertion error:


import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDelta


def scale(x):

    factor = numpyro.sample("factor", dist.HalfNormal(1))
    return x * factor


def model(x, obs=None):
    
    with numpyro.handlers.scope(prefix="test/"):
        
        
        with numpyro.handlers.trace() as tr:
            scaled1 = scale(x)

        with numpyro.handlers.block(), numpyro.handlers.replay(trace=tr):
            scaled2 = scale(x)

        numpyro.deterministic("value_scaled1", scaled1)
        numpyro.deterministic("value_scaled2", scaled2)


    numpyro.sample("obs1", dist.Normal(scaled1, 0.1), obs=obs)


guide = AutoDelta(model)
svi = SVI(
    model,
    guide=guide,
    optim=numpyro.optim.Adam(1),
    loss=numpyro.infer.Trace_ELBO(),
)

x = jnp.arange(100)
run_results = svi.run(
    rng_key=jax.random.PRNGKey(0),
    progress_bar=True,
    num_steps=100,
    x=x,
    obs=x * 2,
)
posterior_samples = guide.sample_posterior(
    jax.random.PRNGKey(0), params=run_results.params, x=x
)

The error message:

    644 def sample(self, key, sample_shape=()):
--> 645     assert is_prng_key(key)
    646     return jnp.abs(self._normal.sample(key, sample_shape))

Without the scope handler, it works perfectly.

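block does not allow scope to see the site. In the second call the site therefore keeps its unprefixed name "factor", while the trace recorded it under the prefixed name, so replay finds nothing to replay and the site is sampled without a PRNG key.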


Thank you!