How to efficiently implement importance sampling?

Hi,

I want to do importance sampling from a distribution with multiple latent variables and one observe statement. My latents are categorical variables and the likelihood is a Gaussian over a generated image. Here are a few things I observed:

  1. It's really memory-hungry: 40k iterations eat up around 10 GB of memory. I understand that it keeps traces, but does it need to keep all variables in memory? For instance, does it also keep the observed likelihood in memory, which is large in my case?
  2. I tried running it on CUDA and, to my surprise, it runs slower: it takes around twice as long as on the CPU. It might be because of the overhead of creating multiple tensors and sending them to the GPU in each call. What is the right way of doing this? Is there a batched version of importance sampling?
  3. When I run it on the CPU, I can see that it only uses one core. However, I think it should run in parallel. How can I make it use multiple CPUs?

Thanks!

Hi, Pyro doesn’t have a vectorized importance sampling implementation at the moment - it draws samples serially and stores them, which is why the performance is so bad. If you post some code we can probably help you speed it up.

Thanks!

Here is my code

# Side length, in pixels, of the square images that ProbModel renders
# (value taken from the rela_net script).
img_length = 75
# Number of pixels in one channel of such an image.
img_size = img_length ** 2


class ProbModel(nn.Module):
    """Generative model of an image containing one coloured shape per colour.

    For every colour in ``self.colors`` the latents are the object's centre
    (``idx``, ``idy``, uniform categorical) and its shape (``obj_type``,
    Bernoulli(0.5); 1 draws a square, anything else a circle).  ``model``
    renders the latents with OpenCV and scores them against an observed image
    under a Gaussian likelihood with std 0.1; ``guide`` is a mean-field
    proposal over the same sites.
    """

    def __init__(self, size=5, use_cuda=False):
        super(ProbModel, self).__init__()
        # OpenCV BGR triples (per the r/g labels).  Only red and green are
        # active; the full palette matching ``color_names`` is
        # (0, 0, 255) r, (0, 255, 0) g, (255, 0, 0) b, (0, 156, 255) o,
        # (128, 128, 128) k, (0, 255, 255) y.
        self.colors = [
            (0, 0, 255),  # red
            (0, 255, 0),  # green
        ]
        self.color_names = ['red', 'green', 'blue', 'orange', 'black', 'yellow']
        self.size = size  # half-width of a square / radius of a circle, in pixels
        if use_cuda:
            self.cuda()
        self.use_cuda = use_cuda

    def deterministic_image(self, idxs, idys, objtypes):
        """Render one image from per-colour object latents.

        Args:
            idxs, idys: per-colour centre x (column) / y (row) coordinates,
                indexable by colour id; ints or integer tensors (CPU or CUDA).
            objtypes: per-colour shape flags; 1 draws a filled square, any
                other value a filled circle.

        Returns:
            float64 numpy array of shape (img_length, img_length, 3) with
            values in [0, 1] on a white background; channel order follows
            ``self.colors`` (BGR).
        """
        size = self.size
        img = np.ones((img_length, img_length, 3)) * 255
        for color_id, color in enumerate(self.colors):
            # cv2 needs plain Python ints for coordinates; int() also copies
            # 0-d tensors (including CUDA ones) back to the host.
            cx = int(idxs[color_id])
            cy = int(idys[color_id])
            if objtypes[color_id] == 1:
                cv2.rectangle(img, (cx - size, cy - size), (cx + size, cy + size), color, -1)
            else:
                cv2.circle(img, (cx, cy), size, color, -1)
        return img / 255.

    # define the model p(x|z)p(z)
    def model(self, x, observe=True, return_img=True):
        """Prior p(z) over object latents and likelihood p(x | z).

        Args:
            x: observed image batch.  ``x.shape[0]`` is the batch size and x
                must hold ``batch * 3 * img_length**2`` elements (e.g. the
                (B, 3, H, W) tensor from ``load_image``).
            observe: if False, the ``obs`` site is skipped, so only the
                prior is sampled.
            return_img: if False, return None instead of the rendered images.
                Pyro traces record the model's return value, so under
                ``Importance`` leaving this True keeps one full image alive
                per sample.  Pass ``return_img=False`` there to save memory.

        Returns:
            numpy array (B, 3, img_length, img_length) of rendered images,
            or None when ``return_img`` is False.
        """
        size = self.size
        n_colors = len(self.colors)
        batch = x.shape[0]
        with pyro.iarange("batch", batch, use_cuda=self.use_cuda):
            # Uniform priors.  Centres keep a `size` margin so shapes stay in frame.
            pixel_probs = x.new_ones(torch.Size((batch, n_colors, img_length - 2 * size)))
            obj_probs = x.new_ones(torch.Size((batch, n_colors))) * 0.5

            idxs = size + pyro.sample("idx", dist.Categorical(pixel_probs).independent(1))
            idys = size + pyro.sample("idy", dist.Categorical(pixel_probs).independent(1))
            objtypes = pyro.sample("obj_type", dist.Bernoulli(obj_probs).independent(1))

            img = x.new_zeros(torch.Size((batch, img_length, img_length, 3)))
            for i, (xs_i, ys_i, types_i) in enumerate(zip(idxs, idys, objtypes)):
                img[i] = torch.from_numpy(self.deterministic_image(xs_i, ys_i, types_i))
            # (B, H, W, C) -> (B, C, H, W) to match the observed image layout.
            # NOTE(review): rendered channels are BGR (cv2 palette) while
            # load_image yields RGB; confirm the target was saved in the
            # same channel order, otherwise the likelihood compares the
            # wrong channels.
            img = img.permute(0, 3, 1, 2)

            img_flat = img.contiguous().view(batch, -1)
            if observe:
                pyro.sample("obs", dist.Normal(img_flat, 0.1).independent(1),
                            obs=x.reshape(batch, -1))
            if not return_img:
                return None
            return img.detach().cpu().numpy()

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x, observe=True):
        """Mean-field proposal q(z | x) over the same latent sites as ``model``.

        ``observe`` is unused.  It lets the guide accept the same arguments
        as ``model``, since Pyro passes identical arguments to both.

        Learnable parameters (per example, per colour):
            idx_param, idy_param: (B, n_colors, img_length - 2*size) weights,
                constrained to the simplex, for the x / y centre offsets.
            obj_param: (B, n_colors) square probabilities, constrained to [0, 1].

        Returns:
            (idxs, idys, objtypes) sampled offsets and shape flags.  Offsets
            are not shifted by ``size``; the model's sites store the same
            unshifted values.
        """
        size = self.size
        n_colors = len(self.colors)
        batch = x.shape[0]
        simplex = torch.distributions.constraints.simplex
        unit_interval = torch.distributions.constraints.unit_interval
        with pyro.iarange("batch", batch, use_cuda=self.use_cuda):
            n_pos = img_length - 2 * size
            # Uniform initial proposals that already satisfy their constraints.
            init_pixel = x.new_ones(torch.Size((batch, n_colors, n_pos))) / n_pos
            init_obj = x.new_ones(torch.Size((batch, n_colors))) * 0.5
            # x and y get independent parameters so they can be learned separately.
            idx_probs = pyro.param("idx_param", init_pixel.clone(), constraint=simplex)
            idy_probs = pyro.param("idy_param", init_pixel.clone(), constraint=simplex)
            obj_probs = pyro.param("obj_param", init_obj, constraint=unit_interval)

            idxs = pyro.sample("idx", dist.Categorical(idx_probs).independent(1))
            idys = pyro.sample("idy", dist.Categorical(idy_probs).independent(1))
            objtypes = pyro.sample("obj_type", dist.Bernoulli(obj_probs).independent(1))
            return idxs, idys, objtypes

# Converts a PIL image into a float tensor of shape (C, H, W) scaled to [0, 1].
loaderTensor = transforms.Compose([
    transforms.ToTensor()])


def load_image(filename):
    """Load an image file as a (1, 3, H, W) float tensor in [0, 1], RGB order.

    The file is opened in a context manager so its handle is closed
    deterministically.  ``convert`` returns an independent, fully loaded
    image, so closing the source afterwards is safe.
    """
    with Image.open(filename) as image:
        rgb = image.convert('RGB')
    return loaderTensor(rgb).unsqueeze(0).float()

USE_CUDA = False

target_img = load_image('../data/target.png')

probmodel = ProbModel(use_cuda=USE_CUDA)
# To condition on a freshly sampled image instead of the loaded target:
#   sampled_img = probmodel.model(target_img, observe=False)
#   condition_img = torch.from_numpy(sampled_img)
condition_img = target_img

if USE_CUDA:
    target_img = target_img.cuda()
    condition_img = condition_img.cuda()


num_samples = 60000
start_time = time.time()
# The second positional argument is observe=False, so the obs site is skipped
# and these are unconditioned (prior) samples, as the plot title states.
posterior = pyro.infer.Importance(probmodel.model, num_samples=num_samples).run(condition_img, False)
marginal_loc = pyro.infer.EmpiricalMarginal(posterior, sites=['idx', 'idy'])
marginal_obj = pyro.infer.EmpiricalMarginal(posterior, sites=['obj_type'])
print('Time taken %s' % (time.time() - start_time))

# Panel 1 shows the conditioning image; panels 2-4 render marginal samples.
fig = plt.figure()
for rr in range(4):
    # Draw on every iteration (including the first) so the sample stream is unchanged.
    [idx], [idy] = marginal_loc()
    [[obj_type]] = marginal_obj()
    ax = fig.add_subplot(2, 2, rr + 1)
    if rr == 0:
        # imshow needs a host array in (H, W, C) layout.
        ax.imshow(condition_img.cpu().squeeze().permute(1, 2, 0))
    else:
        # cv2 needs plain Python numbers, so pull the samples back to the host.
        inf_img = probmodel.deterministic_image(idx.tolist(), idy.tolist(), obj_type.tolist())
        ax.imshow(inf_img)
plt.suptitle('Samples from no-conditioning empirical marginal - ' + str(num_samples))
fig.set_figheight(7)
plt.show()