I’m interested in conditioning on the values of only a subset of variables in a random variable array. The example for condition in the docs goes something like this:
from jax import random
import numpyro
from numpyro.handlers import condition, seed, substitute, trace
import numpyro.distributions as dist
def model():
    numpyro.sample('a', dist.Normal(0., 1.))
model = seed(model, random.PRNGKey(0))
exec_trace = trace(condition(model, {'a': -1})).get_trace()
exec_trace['a']['value']
-1
I’m interested in the following generalization to multiple random variables:
from jax import random
import numpyro
from numpyro.handlers import condition, seed, substitute, trace
import numpyro.distributions as dist
import numpy as np
def model():
    numpyro.sample('a', dist.Normal(0., 1.).expand([2]))
model = seed(model, random.PRNGKey(0))
exec_trace = trace(substitute(model, {'a': [-1, np.nan]})).get_trace()
exec_trace['a']['value']
But the output that I get is:
[-1, nan]
Instead of, for example:
[-1, 0.284742]
Does anyone know how to condition on the values of just some of the variables in an array? Just inserting np.nan doesn’t work, and I couldn’t find any examples in the documentation or in past posts on the forum. Any and all help is very much appreciated.
I think I understand. Essentially, the first sample call puts the observed values into d and the second sample call samples from the distribution for the unobserved sites. Is that right? Should your code read:
That’s right (and thanks for pointing out the typo; I’ve fixed it). In more detail: the first statement sample('a_observed', d, obs=obs) propagates observation information back to the parameters of d, while the second statement sample('a_unobserved', d) draws samples from d, effectively propagating information from d's parameters down to the samples.
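For concreteness, here is a minimal runnable sketch of that two-statement pattern in NumPyro (the prior, mask, and observed values below are illustrative, not from the thread):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers

def model(obs, obs_mask):
    d = dist.Normal(0., 1.).expand([3])
    # observed entries: condition d on the supplied values
    with handlers.mask(mask=obs_mask):
        observed = numpyro.sample('a_observed', d, obs=obs)
    # unobserved entries: draw fresh samples from d
    with handlers.mask(mask=~obs_mask):
        unobserved = numpyro.sample('a_unobserved', d)
    # stitch both halves back together into the full array 'a'
    return numpyro.deterministic('a', jnp.where(obs_mask, observed, unobserved))

obs = jnp.array([-1., 0., 0.])              # dummy values at the unobserved positions
obs_mask = jnp.array([True, False, False])
with handlers.seed(rng_seed=0):
    a = model(obs, obs_mask)                # a[0] == -1.; a[1] and a[2] are sampled from the prior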
Is there something like this available for pyro.param, @fritzo? Can I hold the value of one parameter in an array constant during training while learning the others?
@npschafer we don’t currently have a built-in way to mask-freeze pyro.param statements, but I believe you can easily achieve that effect by mask-detaching:
p = pyro.param("p", lambda: my_default_value, constraint=my_constraint)
# gradients flow only through entries where mask is True
p = torch.where(mask, p, p.detach())
where p[coord] is learned iff mask[coord] == True. Note you’ll need to create a fresh optimizer instance when you begin this clamped phase of training, to avoid momentum leaking from the phase where all parameters were being learned.
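Here is a self-contained sketch of that trick in a small Pyro training loop (the parameter name, mask, data, and model below are made up for illustration and are not part of the answer above):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

mask = torch.tensor([True, False, True])    # learn entries 0 and 2, freeze entry 1

def model(data):
    loc = pyro.param("loc", torch.zeros(3))
    # frozen entries are detached, so they receive no gradient updates
    loc = torch.where(mask, loc, loc.detach())
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.).to_event(1), obs=data)

def guide(data):
    pass                                    # no latent variables; this is plain MLE on the param

data = torch.randn(100, 3) + torch.tensor([2., -2., 0.5])
svi = SVI(model, guide, Adam({"lr": 0.05}), loss=Trace_ELBO())
for _ in range(200):
    svi.step(data)
# loc[0] and loc[2] move toward the data means; loc[1] stays at its initial value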
Would this:
d = dist.Normal(...)
with handlers.mask(mask=obs_mask):
    observed = numpyro.sample('a_observed', d, obs=obs)
with handlers.mask(mask=~obs_mask):
    unobserved = numpyro.sample('a_unobserved', d)
numpyro.deterministic('a', np.where(obs_mask, observed, unobserved))
be expected to give comparable results as compared to this:
with handlers.mask(mask=obs_mask):
    observed = numpyro.sample('a_observed', dist.Normal(...), obs=obs)
with handlers.mask(mask=~obs_mask):
    unobserved = numpyro.sample('a_unobserved', dist.Normal(...))
numpyro.deterministic('a', np.where(obs_mask, observed, unobserved))
assuming the same pyro.params are passed (twice) to both dist.Normal calls in the second case, @fritzo? Or is it important to pass the exact same dist.Normal object in both cases?
The two should behave identically; indeed, NumPyro uses the former internally. Note that this is now built in as of the latest NumPyro release: numpyro.sample(..., obs_mask=...).
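For reference, here is roughly how that built-in form is used (the values and mask below are illustrative; the dummy entry of obs at the unobserved position is ignored):

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed, trace

def model():
    obs = jnp.array([-1., 0.])              # 0. is a dummy value at the unobserved position
    obs_mask = jnp.array([True, False])
    numpyro.sample('a', dist.Normal(0., 1.).expand([2]), obs=obs, obs_mask=obs_mask)

exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
# the trace should contain a combined site 'a' whose value is -1. at the observed
# position and a fresh draw from the prior at the unobserved one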