NaN Error when using pyro.plate

Hi,

I have created a model with a custom truncated normal distribution in pyro. In this model I am doing sub-sampling via the pyro.plate context.

→ When I run the forward model (with the pyro.plate context) to generate some synthetic data, the pyro.sample statement inside the pyro.plate context produces NaN values. The code looks something like:

def model():
    x = torch.tensor(...)
    with pyro.plate("data"):
        intensity = some_function(x,...)
        observed = pyro.sample('observed', TruncatedNormal(...))
    return observed 

The pyro.sample('observed', ...) statement above returns NaNs.

→ But when I remove the pyro.plate statement, the forward model runs fine (no NaNs). The code looks something like:

def model():
    x = torch.tensor(...)
    intensity = some_function(x, ...)
    observed = pyro.sample('observed', TruncatedNormal(...))
    return observed 

In this case, pyro.sample('observed', ...) works fine (no NaNs).

My question is: could this be an issue with my custom truncated normal distribution, or is it something related to pyro.plate?

Thanks,
Atharva

Hi @atharvahans, I don't see why the plate would matter. Custom distributions can be tricky to implement, so feel free to post more code. How big is the plate?

Wow, that's a lot of code in the TruncatedNormal distribution. It's probably best to submit that as a PR to PyTorch (see #32293) and work out bugs through a standard review process and test writing. There are lots of issues: computation in the .__init__() method should be moved to @lazy_property methods, entropy should avoid in-place operations, and .__init__() should avoid overriding self.a and self.b since they are used in .support. Gosh, this is so much more complicated than Alican's https://github.com/pytorch/pytorch/pull/32377/files, I wish we had gotten that merged long ago…
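To make the first and last of those points concrete, here is a rough sketch of the pattern, not the code from this thread or from either PR. It assumes a standard normal truncated to [a, b]; the class name TruncatedStandardNormalSketch and the _normalizer property are made up for illustration. The bounds are stored once in .__init__() and never overwritten, and the derived quantity is computed lazily with torch.distributions.utils.lazy_property.

```python
import math

import torch
from torch.distributions import Distribution, constraints
from torch.distributions.utils import broadcast_all, lazy_property


class TruncatedStandardNormalSketch(Distribution):
    """Hypothetical example: standard normal truncated to [a, b]."""

    arg_constraints = {"a": constraints.real, "b": constraints.real}
    has_rsample = True

    def __init__(self, a, b, validate_args=None):
        # Store the broadcast bounds once and never reassign them:
        # `support` below reads self.a and self.b.
        self.a, self.b = broadcast_all(a, b)
        super().__init__(self.a.shape, validate_args=validate_args)

    @constraints.dependent_property
    def support(self):
        return constraints.interval(self.a, self.b)

    @staticmethod
    def _std_normal_cdf(x):
        return 0.5 * (1 + torch.erf(x / math.sqrt(2)))

    @lazy_property
    def _normalizer(self):
        # Computed on first access and cached on the instance,
        # instead of being computed eagerly in __init__.
        tiny = torch.finfo(self.a.dtype).tiny
        z = self._std_normal_cdf(self.b) - self._std_normal_cdf(self.a)
        return z.clamp_min(tiny)  # out-of-place clamp

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_phi = -0.5 * value**2 - 0.5 * math.log(2 * math.pi)
        return log_phi - self._normalizer.log()

    def rsample(self, sample_shape=torch.Size()):
        # Inverse-CDF sampling restricted to [cdf(a), cdf(b)].
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.a.dtype, device=self.a.device)
        p = self._std_normal_cdf(self.a) + u * self._normalizer
        eps = torch.finfo(p.dtype).eps
        p = p.clamp(eps, 1 - eps)  # keep erfinv away from +/-1
        x = math.sqrt(2) * torch.erfinv(2 * p - 1)
        # Guard against rounding pushing samples outside the support.
        return torch.max(torch.min(x, self.b), self.a)
```

A quick sanity check for this kind of class is to draw a batch with rsample and confirm that the samples contain no NaNs and stay inside [a, b]. A real implementation would also need loc/scale, entropy, expand and careful handling of the tails, all of which this sketch leaves out.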

Okay. I will fix the issues that you pointed out and submit a PR to PyTorch.