Hello everyone
I would like to know whether it is possible to achieve something like the toy example below, which reuses a sample site:
```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDelta


def scale(x):
    factor = numpyro.sample("factor", dist.HalfNormal(1))
    return x * factor


def model(x, obs):
    scaled1 = scale(x)
    # Can I call it again in a way that the sample sites
    # return the same values as in the first call?
    scaled2 = scale(x)
    numpyro.sample("obs1", dist.Normal(scaled1, 0.1), obs=obs)


# Use SVI to infer
svi = SVI(model, guide=AutoDelta(model), optim=numpyro.optim.Adam(1e-3),
          loss=numpyro.infer.Trace_ELBO())
svi.run(
    rng_key=jax.random.PRNGKey(0),
    progress_bar=True,
    num_steps=100,
    x=jnp.arange(100),
    obs=jnp.arange(100) * 2,
)
```
With this code, I get “AssertionError: all sites must have unique names but got factor duplicated”, as expected.
Can the sites in the scale function be reused with an existing handler? Or would it be possible to implement such a handler?
Thanks!
I guess you want something like Automatic Name Generation (Pyro documentation). Please make a feature request for it.
Thank you for such a fast answer!
As far as I understand, automatic name generation creates a new site that does not share the value of the previous one. Is that right?
In my case, I need to obtain the same value of numpyro.sample("factor"), i.e., scaled1 should equal scaled2 in this toy example.
I thought the replay handler could do this, but in my experiments it didn’t work.
I guess you can do something along the lines of

```python
with numpyro.handlers.trace() as tr:
    scaled1 = scale(x)
with numpyro.handlers.block(), numpyro.handlers.replay(trace=tr):
    scaled2 = scale(x)
```
It works!
But maybe I’ve found a bug. Could you kindly confirm whether this is indeed an unexpected error?
If I wrap the code in a scope handler, I get an assertion error:
```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDelta


def scale(x):
    factor = numpyro.sample("factor", dist.HalfNormal(1))
    return x * factor


def model(x, obs=None):
    with numpyro.handlers.scope(prefix="test/"):
        with numpyro.handlers.trace() as tr:
            scaled1 = scale(x)
        with numpyro.handlers.block(), numpyro.handlers.replay(trace=tr):
            scaled2 = scale(x)
    numpyro.deterministic("value_scaled1", scaled1)
    numpyro.deterministic("value_scaled2", scaled2)
    numpyro.sample("obs1", dist.Normal(scaled1, 0.1), obs=obs)


guide = AutoDelta(model)
svi = SVI(
    model,
    guide=guide,
    optim=numpyro.optim.Adam(1),
    loss=numpyro.infer.Trace_ELBO(),
)
x = jnp.arange(100)
run_results = svi.run(
    rng_key=jax.random.PRNGKey(0),
    progress_bar=True,
    num_steps=100,
    x=x,
    obs=x * 2,
)
posterior_samples = guide.sample_posterior(
    jax.random.PRNGKey(0), params=run_results.params, x=x
)
```
The error message:

```
    644 def sample(self, key, sample_shape=()):
--> 645     assert is_prng_key(key)
    646     return jnp.abs(self._normal.sample(key, sample_shape))
```

Without the scope handler, it works perfectly.
`block` does not allow `scope` to see the site.