@carlossouza I think you can use `torch.where` here. We have a similar model in numpyro where we use `np.where`.
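A minimal sketch of the `torch.where` idea. It assumes the model is a switchpoint-style model in which a rate takes one value before an index `tau` and another value from `tau` onward. The names `num_steps`, `tau`, `early_rate` and `late_rate` are hypothetical and not taken from the original question; in a real model `tau` and the rates would be sampled latents.

```python
import torch

# Hypothetical switchpoint setup (illustrative values only).
num_steps = 10
tau = torch.tensor(4)           # switchpoint index; could be a sampled latent tensor
early_rate = torch.tensor(3.0)  # rate used for time steps t < tau
late_rate = torch.tensor(1.0)   # rate used for time steps t >= tau

t = torch.arange(num_steps)
# Elementwise select: the 0-dim rates broadcast against the (num_steps,) condition.
rate = torch.where(t < tau, early_rate, late_rate)
print(rate)
```

Because the comparison `t < tau` is just a tensor op, this also works when `tau` is a tensor with batch dimensions, which is why it tends to fit well inside a vectorized model.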
Edit: I just found an old gist which uses `Tensor.expand` and `torch.cat`.
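A rough sketch of how `expand` plus `torch.cat` can build the same piecewise-constant rate. This is an illustration under the same hypothetical switchpoint assumption (names `num_steps`, `tau`, `early_rate`, `late_rate` are made up), not a reproduction of the gist mentioned above.

```python
import torch

num_steps = 10
tau = 4                         # must be a concrete Python int: it sets tensor sizes
early_rate = torch.tensor(3.0)  # rate for the first tau steps
late_rate = torch.tensor(1.0)   # rate for the remaining steps

# Expand each 0-dim rate to the length of its segment, then concatenate.
rate = torch.cat([
    early_rate.expand(tau),
    late_rate.expand(num_steps - tau),
])
print(rate)
```

The trade-off: segment lengths here must be Python ints, so a sampled `tau` would have to be converted with `int(tau)`, which does not vectorize over batches of `tau`. The `torch.where` version avoids that restriction.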