NaN loss when using custom distribution

Just to introduce the problem that I’m facing:
I have a working implementation of a truncated Dirichlet Process Gaussian mixture model, where I use a Beta variational distribution for the Beta random variables in the stick-breaking construction. Now, because Beta random variables cannot be reparameterized (as far as I know), I was interested in using the Kumaraswamy distribution as an alternative variational posterior, as they do in this paper, and comparing performance.
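
For reference, the Kumaraswamy distribution is easy to reparameterize because its CDF inverts in closed form. With shape parameters a, b > 0 (the same symbols used in the code that follows), the standard definitions are:

```latex
\begin{aligned}
f(x; a, b) &= a\,b\,x^{a-1}\,(1 - x^{a})^{b-1}, \qquad x \in (0, 1),\\
F(x; a, b) &= 1 - (1 - x^{a})^{b},\\
F^{-1}(u; a, b) &= \bigl(1 - (1 - u)^{1/b}\bigr)^{1/a}, \qquad u \sim \mathrm{Uniform}(0, 1).
\end{aligned}
```

Taking logs of the density gives log a + log b + (a - 1) log x + (b - 1) log(1 - x^a). The last two terms diverge as x approaches 0 or 1, respectively.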

The problem is that whenever I swap in my custom Kumaraswamy distribution for the Beta distribution in the model, the ELBO becomes NaN at some point. I tried to debug this, but couldn’t make much progress because of my lack of familiarity with Pyro’s internals.

Here is the code implementing the Kumaraswamy distribution:

import numpy as np
import torch
import pyro
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

class Kuma(pyro.distributions.TorchDistribution):
    has_rsample = True
    arg_constraints = {"a": constraints.positive, "b": constraints.positive}
    support = constraints.unit_interval

    def __init__(self, a, b, validate_args=None):
        self.a, self.b = broadcast_all(a, b)
        super().__init__(batch_shape=self.a.shape, validate_args=validate_args)

    def rsample(self, sample_shape=torch.Size()):
        # Draw one uniform per batch element so the samples are independent.
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.a.dtype, device=self.a.device)
        # Inverse-CDF transform of the Kumaraswamy distribution.
        return torch.pow(1 - torch.pow(1 - u, 1 / self.b), 1 / self.a)

    def log_prob(self, sample):
        return torch.log(self.a * self.b) + (self.b - 1) * torch.log(1 - 
                torch.pow(sample, self.a)) + (self.a - 1) * torch.log(sample)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Kuma, _instance)
        batch_shape = torch.Size(batch_shape)
        new.a = self.a.expand(batch_shape)
        new.b = self.b.expand(batch_shape)
        # Parameters were validated when self was built, so skip re-validation.
        super(Kuma, new).__init__(batch_shape=batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

And for reference, here is the code implementing the truncated DP mixture:

def stick_breaking(betas):
    K = betas.shape[0]
    sticks = torch.ones(K)

    # stick breaking construction of the distribution
    sticks[0] = betas[0]
    sticks[1:K - 1] = betas[1:K - 1] * torch.cumprod(1 - betas, 0)[:K - 2]
    sticks[-1] = (1 - betas[:K]).prod()

    return sticks

def truncated_dp(data, alpha, K):
    sticks = torch.ones(K)

    with pyro.plate("stick_lengths", size=K):
        betas = pyro.sample("betas", pyro.distributions.Beta(1, alpha))

    with pyro.plate("means", size=K):
        mu = pyro.sample("cluster_means", pyro.distributions.Normal(0, 5))

    sticks = stick_breaking(betas)

    with pyro.plate("obs", size=len(data)):
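        # The rest of this call was cut off in the post; Categorical(sticks)
        # with parallel enumeration is an assumed completion.
        cluster_assignments = pyro.sample("assignments",
            pyro.distributions.Categorical(sticks),
            infer={"enumerate": "parallel"})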
        cluster_datapoints = pyro.sample("data", pyro.distributions.Normal(
            mu[cluster_assignments], 1), obs=data)

# variational posterior
def guide(data, alpha, K):
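    # The trailing arguments were cut off in the post; the positivity
    # constraint is an assumed completion.
    beta_hyp1 = pyro.param("betas_hyp1", torch.ones(K),
                           constraint=constraints.positive)
    beta_hyp2 = pyro.param("betas_hyp2", torch.ones(K),
                           constraint=constraints.positive)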
    with pyro.plate("stick_lengths"):
        pyro.sample("betas", pyro.distributions.Beta(beta_hyp1, beta_hyp2))
        # pyro.sample("betas", Kuma(beta_hyp1, beta_hyp2))

    mu_hyp1 = pyro.param("mu_hyp1", torch.zeros(K))

    with pyro.plate("means"):
        pyro.sample("cluster_means", pyro.distributions.Delta(mu_hyp1))

And, here is the code to fit the model:

if __name__ == "__main__":
    K = 10
    alpha = 1

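    # The remaining arguments were cut off in the post; guide=guide and the
    # TraceEnum_ELBO loss are assumed completions.
    svi = pyro.infer.SVI(model=truncated_dp,
            guide=guide,
            optim=pyro.optim.Adam({"lr": .004}),
            loss=pyro.infer.TraceEnum_ELBO(max_plate_nesting=1))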

    # fake data
    data = torch.tensor(np.concatenate([
        np.random.normal(5, 1, size=50),
        np.random.normal(-5, 1, size=50)]), dtype=torch.float32)
    elbo = np.zeros(2000)
    for step in range(2000):
        if step % 100 == 0:
            print("step: {}".format(step))
        elbo[step] = svi.step(data, alpha, K)

Let me know if anything is unclear or you need more information. Thanks!

i ran your code three times and don’t see any NaNs. if you are seeing NaNs it may help to try one of the following:
i) reduce the learning rate
ii) make the beta1 parameter in Adam larger (e.g. 0.95)
iii) clip gradients or parameters
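
To make suggestion (iii) concrete, here is a minimal plain-Python sketch of clipping a gradient by its overall norm. The function name `clip_by_global_norm` and the example values are made up for illustration; this is not Pyro's implementation.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient values so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm or norm == 0.0:
        return list(grads)  # already inside the bound, leave untouched
    scale = max_norm / norm
    return [g * scale for g in grads]

clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)    # norm 5, rescaled to norm 1
unchanged = clip_by_global_norm([0.3, 0.4], max_norm=1.0)  # norm 0.5, returned as is
print(clipped, unchanged)
```

In Pyro itself, pyro.optim.ClippedAdam takes a clip_norm entry in its options dict alongside lr and betas, so all three suggestions can be tried from the optimizer configuration. Clipping only tames large finite gradients: a gradient that is already NaN stays NaN.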

btw fyi the beta distribution can be “reparameterized” and the rsample method is implemented in torch.distributions and is thus available in Pyro (see reference)

Hi Martin,

That was my mistake; in the guide, under pyro.plate("stick_lengths"), you’ll see that I commented out the Kumaraswamy sample statement and used a Beta instead. If you uncomment the Kuma line and comment out the Beta sample statement, you should see NaNs appear.

i suspect this is primarily a numerical issue and has nothing to do with Pyro as such. e.g. if i invoke torch.set_default_dtype(torch.float64) i do not encounter any NaNs.
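
A small plain-Python sketch of why precision can matter here. It emulates float32 by rounding every intermediate value through `struct`, then evaluates the inverse-CDF sampler from the Kuma class in both precisions. The draw u and the parameters a, b are hypothetical values chosen for illustration, the helper names are made up, and the rounding only approximates what torch does, so treat this as an illustration of the mechanism rather than a diagnosis of this particular run.

```python
import math
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def kuma_inverse_cdf(u, a, b, single_precision):
    """(1 - (1 - u)**(1/b))**(1/a), optionally rounding each step to float32."""
    rnd = to_float32 if single_precision else (lambda v: v)
    u, a, b = rnd(u), rnd(a), rnd(b)
    inner = rnd(math.pow(rnd(1.0 - u), rnd(1.0 / b)))      # (1 - u)**(1/b)
    return rnd(math.pow(rnd(1.0 - inner), rnd(1.0 / a)))   # (1 - inner)**(1/a)

u, a, b = 2.0 ** -23, 1.0, 5.0  # hypothetical uniform draw and parameters
x64 = kuma_inverse_cdf(u, a, b, single_precision=False)  # small but positive
x32 = kuma_inverse_cdf(u, a, b, single_precision=True)   # collapses to exactly 0.0

# torch.log(0) is -inf, so with a == 1 the term (a - 1) * log(x) is 0 * (-inf).
nan_term = (a - 1.0) * float("-inf")
print(x64, x32, nan_term)
```

In the emulated float32 run, (1 - u)**(1/b) rounds to exactly 1, the sample becomes 0, and log_prob is no longer finite. In float64 the same draw gives a small positive sample.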

possible places to look include:

  • use torch.log1p(-x) instead of torch.log(1 - x)
  • instead of using cumprod and the like try to do the entire computation in log space
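
Both suggestions can be sketched in plain Python with the `math` module standing in for torch. The function names below are made up for illustration. Note also that `log_stick_weights` assumes the last component takes whatever stick is left after the first K - 1 breaks, which is not exactly what the posted stick_breaking does.

```python
import math

def kuma_log_prob(x, a, b):
    """Kumaraswamy log-density with log(1 - x**a) written as log1p(-x**a)."""
    return (math.log(a) + math.log(b)
            + (a - 1.0) * math.log(x)
            + (b - 1.0) * math.log1p(-math.pow(x, a)))

def log_stick_weights(betas):
    """Log mixture weights from stick-breaking fractions, accumulated in log space."""
    log_remaining = 0.0  # log of the stick length not yet handed out
    log_weights = []
    for beta in betas[:-1]:
        log_weights.append(math.log(beta) + log_remaining)
        log_remaining += math.log1p(-beta)  # a running sum replaces cumprod
    log_weights.append(log_remaining)  # last component takes the remainder
    return log_weights

# log(1 - t) loses a tiny t entirely, while log1p(-t) keeps it.
tiny = 1e-20
naive = math.log(1.0 - tiny)  # 1.0 - 1e-20 rounds to 1.0, so this is exactly 0.0
stable = math.log1p(-tiny)    # approximately -1e-20

weights = [math.exp(lw) for lw in log_stick_weights([0.5, 0.5, 0.5])]
print(naive, stable, weights)
```

In a torch version the loop would presumably become a torch.cumsum over torch.log1p(-betas), and since torch.distributions.Categorical accepts a logits argument, the weights never need to be exponentiated.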