Hi,
I have created a model in Pyro that uses a custom truncated normal distribution. In this model I am doing subsampling via the pyro.plate context.
→ When I run the forward model (with the pyro.plate context) to generate some synthetic data, the pyro.sample statement inside the pyro.plate context produces NaN values. The code looks something like this:
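```python
import torch
import pyro

def model():
    x = torch.tensor(...)
    with pyro.plate("data"):
        intensity = some_function(x, ...)
        # TruncatedNormal is my custom distribution (definition not shown)
        observed = pyro.sample('observed', TruncatedNormal(...))
    return observed
```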
The pyro.sample('observed', ...) statement above returns NaNs.
→ But when I remove the pyro.plate statement, the forward model runs fine (no NaNs). The code looks something like this:
```python
def model():
    x = torch.tensor(...)
    intensity = some_function(x, ...)
    observed = pyro.sample('observed', TruncatedNormal(...))
    return observed
```
In this case, pyro.sample('observed', ...) works fine (no NaNs).
My question is: could this be an issue with my custom truncated normal distribution, or could it be something related to pyro.plate?
Thanks,
Atharva