Masked sampling/deterministically calculating

I am doing MCMC inference using NUTS.

My model looks basically like this:

def model(A_obs, B_obs):
  A = sample('A', Normal(5, 1), obs=A_obs)
  
  weights_mask = sample('weights_mask', Dirichlet(jnp.ones(2)))
  mask = sample('mask', Categorical(weights_mask))
  
  # Based on mask I would like to now either calculate B deterministically from A or sample from Exp (element-wise)
  scale = sample('scale', HalfNormal(1))
  B = sample('B', Exponential(scale), obs=B_obs, mask=(mask==1)) # note: not a valid statement
  B = deterministic('B', A * val, mask=(mask==0)) # note: not a valid statement

The last two lines are what I am trying to come up with.
I am trying to get to B in two ways based on the mask variable:

  • if mask == 0: Calculate B deterministically based on A
  • if mask == 1: sample and infer B from observed values

I found the effect handlers and thought either block or mask should do the trick. But I think block would suppress inference on the entirety of B, which is not what I want. For mask, I don’t understand how I would incorporate the sampled mask variable, which should also be inferred.

Thanks in advance!

I guess you can do something like this:

B_mask = sample('B_mask', Exponential(scale).mask(mask), obs=B_obs)
B = deterministic('B', jnp.where(mask, B_mask, A * val))

Hey @fehiepsi!

Thanks for your help. Unfortunately I couldn’t get the code running.
I received this error:

ValueError: Missing a plate statement for batch dimension -1 at site 'A'. You can use numpyro.util.format_shapes utility to check shapes at all sites of your model.

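The model:

import jax.numpy as jnp
from numpyro import deterministic, sample
from numpyro.distributions import Categorical, Dirichlet, Exponential, HalfNormal, Normal

def model(A_obs, B_obs):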
    A = sample('A', Normal(5, 1), obs=A_obs)

    weights_mask = sample('weights_mask', Dirichlet(jnp.ones(2)))
    mask = sample('mask', Categorical(weights_mask))

    scale = sample('scale', HalfNormal(1))
    B_mask = sample('B_mask', Exponential(scale).mask(mask), obs=B_obs)
    B = deterministic('B', jnp.where(mask, B_mask, A * val))

Checking the trace shapes gives:

Trace Shapes:         
     Param Sites:         
    Sample Sites:         
           A dist      |  
            value 1000 |  
weights_mask dist      | 2
            value      | 2
        mask dist      |  
            value      |  
       scale dist      |  
            value      |  
      B_mask dist      |  
            value 1000 |  

I also received a FutureWarning saying that enumerated sites need to be marked with infer={'enumerate': 'parallel'}. I then changed the mask sample site to:

    mask = sample('mask', Categorical(weights_mask), infer={'enumerate': 'parallel'})

which removed the FutureWarning.
The actual error, though, remained.

Maybe some additional information:
Ideally, instead of computing B deterministically, I would like to sample it from a Normal with mean A*val and some noise as the variance. The problem here is that A*val is an array with the same length as A.

Thanks in advance. If you need any further information, please let me know!