@JMellor I think with `ragged=True`, you will want to replace `MultivariateNormal` by `Normal`. MVN’s log probability of a vector is a single scalar, so the mask, which is applied per element of the log probability, has no per-step entries to zero out and will not work correctly. In addition, inside `plate(size=T)`, `MVN(np.zeros(T)).sample()` will return a matrix with shape `T x T`, which, I think, is not what you want, whereas `plate(size=T)` + `Normal(np.zeros(T))` will return a vector with shape `T`. Finally, MVN is not smart enough to marginalize out masked variables: if `z = (x, y) ~ MVN`, then we can’t derive `MVN.log_prob(y)` from `MVN.log_prob(z)` using the `mask` handler.
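To make the three points concrete, here is a rough NumPy-only sketch (not NumPyro's API). The helpers `normal_log_prob` and `mvn_log_prob` are hypothetical stand-ins for the two distributions' `log_prob`, the mask and observations are made-up values for a ragged series with 2 valid steps out of `T = 4`, and the AR(1)-style covariance is just an arbitrary correlated example. The sampling lines use NumPy's `size=` argument as an analogy for drawing inside a plate of size `T`.

```python
import numpy as np

def normal_log_prob(x, loc, scale):
    """Hypothetical helper: independent Normals give one log density per entry."""
    return -0.5 * np.log(2 * np.pi * scale**2) - 0.5 * ((x - loc) / scale) ** 2

def mvn_log_prob(x, loc, cov):
    """Hypothetical helper: a multivariate normal gives one scalar for the whole vector."""
    d = x.shape[-1]
    diff = x - loc
    _, logdet = np.linalg.slogdet(cov)
    maha = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)

T = 4
x = np.array([0.3, -1.2, 0.5, 2.0])
mask = np.array([True, True, False, False])  # assumed: last two steps are padding

# Point 1: Normal's log_prob has shape (T,), so a mask can zero out padded steps.
lp_normal = normal_log_prob(x, 0.0, 1.0)
masked_normal = np.where(mask, lp_normal, 0.0).sum()

# MVN's log_prob is a scalar: a per-step mask has nothing to select.
lp_mvn = mvn_log_prob(x, np.zeros(T), np.eye(T))

# For independent coordinates, the masked sum equals the marginal of the valid steps.
marginal_indep = mvn_log_prob(x[mask], np.zeros(2), np.eye(2))

# Point 2 (NumPy analogy): T independent draws of a length-T MVN vector are T x T,
# while a batch of T scalar Normals is just length T.
rng = np.random.default_rng(0)
mvn_draws = rng.multivariate_normal(np.zeros(T), np.eye(T), size=T)
normal_draws = rng.normal(np.zeros(T), 1.0)

# Point 3: with correlations, the marginal of y = x[mask] needs the sub-covariance.
# Masking the joint scalar can only keep it or drop it; neither gives this value.
rho = 0.8
cov = rho ** np.abs(np.subtract.outer(np.arange(T), np.arange(T)))
joint_corr = mvn_log_prob(x, np.zeros(T), cov)
marginal_corr = mvn_log_prob(x[mask], np.zeros(2), cov[np.ix_(mask, mask)])
```

With an identity covariance the MVN joint happens to equal the sum of the per-step Normal terms, which is why swapping to `Normal` loses nothing there while making the mask usable. With a correlated covariance there is no per-step decomposition to mask, so the marginal has to be computed from the sub-covariance directly.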