Condition on the values of part of an array of random variables

I’m interested in conditioning on the values of only a subset of variables in a random variable array. The example for condition in the docs goes something like this:

from jax import random
import numpyro
from numpyro.handlers import condition, seed, substitute, trace
import numpyro.distributions as dist

def model():
    numpyro.sample('a', dist.Normal(0., 1.))

model = seed(model, random.PRNGKey(0))
exec_trace = trace(condition(model, {'a': -1})).get_trace()
exec_trace['a']['value']
-1

I’m interested in the following generalization to multiple random variables:

from jax import random
import numpy as np
import numpyro
from numpyro.handlers import condition, seed, substitute, trace
import numpyro.distributions as dist

def model():
    numpyro.sample('a', dist.Normal(0., 1.).expand([2]))

model = seed(model, random.PRNGKey(0))
exec_trace = trace(substitute(model, {'a': [-1, np.nan]})).get_trace()
exec_trace['a']['value']

But the output that I get is:

[-1, nan]

Instead of, for example:

[-1, 0.284742]

Does anyone know how to condition on the values of just some of the variables in an array? Just inserting np.nan doesn’t work, and I couldn’t find any examples in the documentation or in past forum posts. Any and all help is very much appreciated.

Hi @npschafer, if you don’t need to sample the remaining values of the array, you can simply use a single mask, something like

with handlers.mask(mask=obs_mask):
    numpyro.sample('a', dist.Normal(...), obs=obs)

but you should still set valid (non-NaN) values for the unobserved entries of obs to avoid NaN pollution.
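
For example, applied to the two-element model from the question, a minimal sketch might look like this (the obs and obs_mask values are made up for illustration; the 0.0 placeholder is arbitrary, since the mask removes its contribution to the log density):

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.handlers import seed, trace

obs = jnp.array([-1.0, 0.0])          # dummy 0.0 in the unobserved slot, not nan
obs_mask = jnp.array([True, False])   # only the first entry is observed

def model():
    with handlers.mask(mask=obs_mask):
        numpyro.sample('a', dist.Normal(0., 1.).expand([2]), obs=obs)

exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
exec_trace['a']['value']  # [-1., 0.] -- the dummy value is returned, not resampled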

If you do need to sample the remaining values, you can use three statements, as in

d = dist.Normal(...)
with handlers.mask(mask=obs_mask):
    observed = numpyro.sample('a_observed', d, obs=obs)
with handlers.mask(mask=~obs_mask):
    unobserved = numpyro.sample('a_unobserved', d)
numpyro.deterministic('a', jnp.where(obs_mask, observed, unobserved))
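
Putting that together with the model from the question, a runnable sketch might look like this (again, the obs and obs_mask values are invented for illustration):

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.handlers import seed, trace

obs = jnp.array([-1.0, 0.0])
obs_mask = jnp.array([True, False])

def model():
    d = dist.Normal(0., 1.).expand([2])
    with handlers.mask(mask=obs_mask):
        observed = numpyro.sample('a_observed', d, obs=obs)
    with handlers.mask(mask=~obs_mask):
        unobserved = numpyro.sample('a_unobserved', d)
    return numpyro.deterministic('a', jnp.where(obs_mask, observed, unobserved))

exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
exec_trace['a']['value']  # first entry is -1., second is a fresh draw from d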

We’ll probably be adding this as a simplified syntax in the next release (Pyro PR):

numpyro.sample('a', d, obs=obs, obs_mask=obs_mask)

EDIT: fixed code typos


Hi @fritzo, thanks for the prompt and positive response. While waiting for someone to reply, I came up with this:

from jax import random
import jax.numpy as jnp
import numpyro
from numpyro.handlers import condition, seed, substitute, trace
import numpyro.distributions as dist

def model():
    a_unobserved = numpyro.sample('a_unobserved', dist.Normal(0., 1.))
    a = numpyro.deterministic('a', jnp.hstack([1, a_unobserved]))

model = seed(model, random.PRNGKey(0))
exec_trace = trace(model).get_trace()
exec_trace['a']['value']
DeviceArray([ 1.       , -1.2515389], dtype=float32)

Can you explain to me the difference between your solution and this one?

Much appreciated!
Nick

I think I understand. Essentially, the first sample call feeds the observed values into d, and the second sample call samples from the distribution for the unobserved sites. Is that right? Should your code read:

observed = numpyro.sample('a_observed', d, obs=obs)

and likewise for unobserved?

That’s right (and thanks for pointing out the typo; I’ve fixed it). In more detail: the first statement sample('a_observed', d, obs=obs) propagates observation information back to the parameters of d, while the second statement sample('a_unobserved', d) draws samples from d, effectively propagating information from d's parameters down to samples.
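
You can see this split by inspecting a trace of the three-statement sketch above: the observed site is flagged as conditioned, while the unobserved site is sampled:

exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
exec_trace['a_observed']['is_observed']    # True  -- scored against obs
exec_trace['a_unobserved']['is_observed']  # False -- drawn from d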


Great - thanks for taking the time to explain it.

Is there something like this available for pyro.param, @fritzo? Can I hold some entries of a parameter array constant during training while learning the others?

@npschafer we don’t currently have a built-in way to mask-freeze pyro.param statements, but I believe you can easily achieve that effect by mask-detaching:

p = pyro.param("p", lambda: my_default_value, constraint=my_constraint)
p = torch.where(mask, p, p.detach())

where p[coord] is learned iff mask[coord] == True. Note you’ll need to create a fresh optimizer instance when you begin this clamped phase of training, to avoid momentum leaking from the phase where all parameters were being learned.
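
A minimal end-to-end sketch of that pattern (the model, data, and mask here are invented for illustration):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

mask = torch.tensor([True, False])  # learn p[0], hold p[1] fixed

def model(data):
    p = pyro.param("p", lambda: torch.zeros(2))
    p = torch.where(mask, p, p.detach())  # gradients flow only where mask is True
    pyro.sample("obs", dist.Normal(p, 1.).to_event(1), obs=data)

def guide(data):
    pass

# a fresh optimizer for the clamped phase, so no stale momentum carries over
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
svi.step(torch.tensor([1.0, 2.0]))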


Yes, it seems to be working. Cheers.

Would this:

d = dist.Normal(...)
with handlers.mask(mask=obs_mask):
    observed = numpyro.sample('a_observed', d, obs=obs)
with handlers.mask(mask=~obs_mask):
    unobserved = numpyro.sample('a_unobserved', d)
numpyro.deterministic('a', jnp.where(obs_mask, observed, unobserved))

be expected to give the same results as this:

with handlers.mask(mask=obs_mask):
    observed = numpyro.sample('a_observed', dist.Normal(...), obs=obs)
with handlers.mask(mask=~obs_mask):
    unobserved = numpyro.sample('a_unobserved', dist.Normal(...))
numpyro.deterministic('a', jnp.where(obs_mask, observed, unobserved))

assuming the same numpyro.param values are passed (twice) to the two dist.Normal calls in the second case, @fritzo? Or is it important to pass the exact same dist.Normal object in both cases?

The two should behave identically; indeed, NumPyro’s own implementation uses the former pattern. Note that this is available in the latest NumPyro release: numpyro.sample(..., obs_mask=...).
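
For completeness, a sketch of that syntax applied to the original model (assuming a NumPyro version with obs_mask support; unobserved entries are imputed from the prior):

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import seed, trace

obs = jnp.array([-1.0, 0.0])
obs_mask = jnp.array([True, False])

def model():
    return numpyro.sample('a', dist.Normal(0., 1.).expand([2]),
                          obs=obs, obs_mask=obs_mask)

exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
exec_trace['a']['value']  # [-1., <imputed sample>]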


Great; thanks for confirming.