Just to introduce the problem that I’m facing:
I have a working implementation of a truncated Dirichlet process Gaussian mixture model, where I use a Beta variational distribution for the Beta random variables in the stick-breaking construction. Now, because Beta random variables cannot be reparameterized (as far as I know), I was interested in using the Kumaraswamy distribution as an alternative variational posterior, as they do in this paper - [1605.06197] Stick-Breaking Variational Autoencoders, and comparing performance.
The problem I have is that whenever I swap my custom Kumaraswamy distribution in for the Beta distribution in the guide, at some point the ELBO becomes NaN. I tried debugging this, but couldn't make much progress because of my lack of familiarity with Pyro's internals.
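For what it's worth, the kind of check I've been running looks roughly like this (just a sketch wrapped around the training loop shown at the bottom; set_detect_anomaly is standard torch):

# rough sketch of the debugging I tried, around the training loop below
torch.autograd.set_detect_anomaly(True)  # make backward raise where a NaN gradient appears
for step in range(2000):
    loss = svi.step(data, alpha, K)
    if loss != loss:  # NaN check on the returned loss (a Python float)
        print("loss became NaN at step", step)
        break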
Here is the code implementing the Kumaraswamy distribution:
import numpy as np
import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

import pyro


class Kuma(pyro.distributions.TorchDistribution):
    """Kumaraswamy distribution with a reparameterized sampler (inverse-CDF transform)."""
    has_rsample = True
    arg_constraints = {"a": constraints.positive, "b": constraints.positive}
    support = constraints.unit_interval

    def __init__(self, a, b):
        self.a, self.b = broadcast_all(a, b)
        self.unif = torch.distributions.Uniform(0, 1)
        super().__init__(batch_shape=self.a.shape, event_shape=torch.Size([]))

    def rsample(self, sample_shape=torch.Size([1])):
        # inverse-CDF transform of one uniform draw per batch element; differentiable in a and b
        u = self.unif.sample(self._extended_shape(sample_shape))
        return torch.pow(1 - torch.pow(1 - u, 1 / self.b), 1 / self.a)

    def log_prob(self, sample):
        # log p(x) = log(a b) + (b - 1) log(1 - x^a) + (a - 1) log(x)
        return (torch.log(self.a * self.b)
                + (self.b - 1) * torch.log(1 - torch.pow(sample, self.a))
                + (self.a - 1) * torch.log(sample))

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Kuma, _instance)
        batch_shape = torch.Size(batch_shape)
        new.a = self.a.expand(batch_shape)
        new.b = self.b.expand(batch_shape)
        new.unif = torch.distributions.Uniform(0, 1)
        super(Kuma, new).__init__(batch_shape=batch_shape)
        return new
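A quick way to sanity-check that rsample really is reparameterized (just a sketch, not part of the model):

# gradients should flow from the sample back to a and b
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
x = Kuma(a, b).rsample()
x.sum().backward()
print(a.grad, b.grad)  # expect both to be finite and non-None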
And for reference, here is the code implementing the truncated DP mixture:
def stick_breaking(betas):
    K = betas.shape[0]
    sticks = torch.ones(K)
    # stick breaking construction: each stick is beta_k times the length left over so far
    sticks[0] = betas[0]
    sticks[1:K - 1] = betas[1:K - 1] * torch.cumprod(1 - betas, 0)[:K - 2]
    sticks[-1] = (1 - betas[:K - 1]).prod()  # residual mass goes to the last stick
    return sticks
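As a small check of the helper (sketch only, with arbitrary values), the resulting weights should be non-negative and sum to one:

betas = torch.rand(5)
sticks = stick_breaking(betas)
print(sticks, sticks.sum())  # expect a length-5 simplex, sum close to 1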
@pyro.infer.config_enumerate
def truncated_dp(data, alpha, K):
    # sample K stick-breaking proportions and K cluster means
    with pyro.plate("stick_lengths", size=K):
        betas = pyro.sample("betas", pyro.distributions.Beta(1, alpha))
    with pyro.plate("means", size=K):
        mu = pyro.sample("cluster_means", pyro.distributions.Normal(0, 5))
    sticks = stick_breaking(betas)
    with pyro.plate("obs", size=len(data)):
        cluster_assignments = pyro.sample("assignments",
                                          pyro.distributions.Categorical(sticks))
        pyro.sample("data", pyro.distributions.Normal(mu[cluster_assignments], 1),
                    obs=data)
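For checking shapes (especially under enumeration) I look at a single trace of the model like this (sketch; data, alpha and K are as in the __main__ block further down):

# print the shapes at every sample site of one model trace
trace = pyro.poutine.trace(truncated_dp).get_trace(data, alpha, K)
print(trace.format_shapes())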
# variational posterior
def guide(data, alpha, K):
    beta_hyp1 = pyro.param("betas_hyp1", torch.ones(K),
                           constraint=constraints.positive)
    beta_hyp2 = pyro.param("betas_hyp2", torch.ones(K),
                           constraint=constraints.positive)
    with pyro.plate("stick_lengths", size=K):
        pyro.sample("betas", pyro.distributions.Beta(beta_hyp1, beta_hyp2))
        # pyro.sample("betas", Kuma(beta_hyp1, beta_hyp2))  # the swap that triggers the NaN
    mu_hyp1 = pyro.param("mu_hyp1", torch.zeros(K))
    with pyro.plate("means", size=K):
        pyro.sample("cluster_means", pyro.distributions.Delta(mu_hyp1))
And here is the code to fit the model:
if __name__ == "__main__":
    K = 10
    alpha = 1
    pyro.clear_param_store()
    svi = pyro.infer.SVI(model=truncated_dp,
                         guide=guide,
                         optim=pyro.optim.Adam({"lr": .004}),
                         loss=pyro.infer.TraceEnum_ELBO(num_particles=1))
    # fake data: two well-separated Gaussian clusters
    data = torch.tensor(np.concatenate([
        np.random.normal(5, 1, size=50),
        np.random.normal(-5, 1, size=50)]), dtype=torch.float32)
    elbo = np.zeros(2000)
    for step in range(2000):
        if step % 100 == 0:
            print("step: {}".format(step))
        elbo[step] = svi.step(data, alpha, K)
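And after a run, the first step where the ELBO estimate turned NaN can be read off the elbo array (sketch):

# locate the first NaN in the recorded ELBO trace, if there is one
nan_steps = np.where(np.isnan(elbo))[0]
if len(nan_steps) > 0:
    print("ELBO first became NaN at step", nan_steps[0])
else:
    print("no NaN in the ELBO trace")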
Let me know if anything is unclear or you need more information. Thanks!