How to encode censored observations on sums of Bernoulli variables

I am trying to model a discrete-time process with discrete outputs. At the end of every time step, I get censored observations on the sums of some Bernoulli random variables. The observations are censored in that they cap out once the sum reaches two, so all I observe is 0, 1, or >=2. I am confused about how to represent this in Pyro.
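For reference, here is the likelihood this observation model implies. It assumes (an assumption, not something stated above) that the Bernoulli variables at a given time step are independent given the latent parameters, with success probabilities p_1, ..., p_n. The value Y = 2 stands for the censored ">= 2" outcome:

```latex
% Assumption: X_1, \dots, X_n independent, X_i \sim \mathrm{Bernoulli}(p_i)
\begin{aligned}
Y &= \min\Big(\sum_{i=1}^{n} X_i,\; 2\Big) \\
P(Y = 0) &= \prod_{i=1}^{n} (1 - p_i) \\
P(Y = 1) &= \sum_{i=1}^{n} p_i \prod_{j \ne i} (1 - p_j) \\
P(Y = 2) &= 1 - P(Y = 0) - P(Y = 1)
\end{aligned}
```

Because these three probabilities depend only on the p_i, the censored value can in principle be scored as a three-category likelihood without sampling the individual X_i.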

In principle I could write this as pyro.deterministic("obs", min(x1 + x2 + x3 + x4 + x5 + ..., 2)), but I can't pass obs= to a deterministic site. I noticed in a reply to this question that I could try using condition on these observations, but that user didn't seem satisfied with the results. Is this problem a good fit for Pyro?

If it matters, my ultimate goal is to learn some latent parameters of the model and forecast it into the future conditional on the observations.

Thanks
