I am doing MCMC inference using NUTS.
My model looks basically like this:
```python
import jax.numpy as jnp
from numpyro import sample, deterministic
from numpyro.distributions import Normal, Dirichlet, Categorical, HalfNormal, Exponential

def model(A_obs, B_obs):
    A = sample('A', Normal(5, 1), obs=A_obs)
    weights_mask = sample('weights_mask', Dirichlet(jnp.ones(2)))
    mask = sample('mask', Categorical(weights_mask))
    # Based on mask, I would like to either calculate B deterministically from A
    # or sample it from an Exponential (element-wise)
    scale = sample('scale', HalfNormal(1))
    B = sample('B', Exponential(scale), obs=B_obs, mask=(mask == 1))  # note: not a valid statement
    B = deterministic('B', A * val, mask=(mask == 0))  # note: not a valid statement (`val` not shown)
```
The last two lines are what I am trying to come up with; neither is valid as written.
I want to obtain B in one of two ways, depending on the `mask` variable:
- if mask == 0: Calculate B deterministically based on A
- if mask == 1: sample and infer B from observed values
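If the mask is summed out of the likelihood rather than sampled, the per-element probability of each case in the list above can still be recovered afterwards from parameter values (for example, once per posterior draw). A minimal sketch in JAX, assuming the deterministic case is relaxed to `Normal(A * val, eps)`; `mask_probability`, `val`, `rate`, and `eps` are hypothetical names.

```python
import jax.numpy as jnp
from jax.nn import softmax
from jax.scipy.stats import norm


def mask_probability(A, B, weights, val, rate, eps=1e-2):
    """Return P(mask_i == 1 | A_i, B_i, params) for every element i."""
    # mask == 0: B ~= A * val, relaxed to Normal(A * val, eps) (assumption)
    lp0 = jnp.log(weights[0]) + norm.logpdf(B, loc=A * val, scale=eps)
    # mask == 1: Exponential(rate) log-density, -inf for negative B
    lp1 = jnp.where(
        B >= 0, jnp.log(weights[1]) + jnp.log(rate) - rate * B, -jnp.inf
    )
    # Normalize the two case log-weights per element
    probs = softmax(jnp.stack([lp0, lp1], axis=-1), axis=-1)
    return probs[..., 1]


A = jnp.array([5.0, 5.0])
B = jnp.array([5.0, 0.3])  # first equals A * val exactly, second does not
p = mask_probability(A, B, weights=jnp.array([0.5, 0.5]), val=1.0, rate=1.0)
```

Applied to each posterior draw of `weights`, `val`, and `rate`, this gives a posterior over the mask without the sampler ever handling a discrete variable.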
I found the effect handlers and thought either `block` or `mask` should do the trick. However, I think `block` would suppress inference on B entirely, which is not what I want. For `mask`, I don't understand how I would incorporate the sampled `mask` variable, which should also be inferred.
Thanks in advance!