I need to create real-nvp transform by permute and affine-coupling layers in order to create a guide for SVI in the below toy problem, but permute and affine-coupling cause errors. I got " object has no attribute 'index_select". Can any body please help?
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pyro
from pyro import optim
from pyro.distributions import constraints
from pyro.distributions.transforms import iterated,AffineCoupling,affine_coupling,Permute
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormalizingFlow
from pyro.nn import DenseNN
import pyro.distributions as dist
sample_shape=2
class BananaShaped(dist.TorchDistribution):
support = constraints.real_vector
validate_args = True
def __init__(self):
super(BananaShaped, self).__init__(event_shape=(2,))
def sample(self, sample_shape=()):
return torch.zeros(sample_shape + self.event_shape)
def log_prob(self, x):
term1 = 0.5 * ((torch.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2
term2 = -0.5 * ((x[..., :1] + torch.tensor([-2., 2.])) / 0.6) ** 2
pe = term1 - torch.logsumexp(term2, axis=-1)
return -pe
def model():
pyro.sample("x", BananaShaped())
transform1 = dist.transforms.ComposeTransform([affine_coupling,Permute(torch.randperm(2, dtype=torch.long)),affine_coupling])
transform2=dist.transforms.ComposeTransform([Permute(torch.randperm(2, dtype=torch.long))])
guide = AutoNormalizingFlow(model, transform2)
def fit_guide(guide):
pyro.clear_param_store()
adam = optim.Adam({"lr": 0.01})
svi = SVI(model, guide, adam, Trace_ELBO())
for i in range(5000):
loss = svi.step()
if i % 250 == 0:
print(loss)
fit_guide(guide)
with pyro.plate("N", 2000):
guide_samples = guide()["x"].detach().cpu().numpy()
plt.figure(16)
samples = guide_samples
sns.scatterplot(samples[:, 0], samples[:, 1])
plt.show()