dist.Bernoulli recognizes the right batch_size, but dist.Categorical won't

Hi all, I'm back again, sorry for bothering!
The original xs sampling code in model() of ss_vae_M2.py is:

loc = self.decoder.forward([zs, ys])
x = pyro.sample("x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs)

where the shape of loc is [10, 200, 784] (10 for parallel enumeration, 200 for the original batch_size, 784 pixels per digit image), and dist.Bernoulli correctly recognizes that the batch_shape is [10, 200] and the event_shape is [784]:

b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

and we have:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])

However, when I replace Bernoulli with Categorical, leaving all the other code and data the same:

loc = self.decoder.forward([zs, ys])
x = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
 print(f"b event_shape:{b.event_shape}")

the result becomes:

b batch_shape:torch.Size([10])
b event_shape:torch.Size([200])

Could someone help me figure this out?

I looked at the original implementation of Categorical in torch.distributions.Categorical and found that it recognizes the right shapes as long as “.to_event(1)” is not applied:

b = dist.Categorical(loc, validate_args=False)
print(f"b _batch_shape:{b._batch_shape}")
print(f"b _num_events:{b._num_events}")

output:

b _batch_shape:torch.Size([10, 200])
b _num_events:784

Note that _batch_shape and _num_events are attributes of the original torch.distributions.Categorical implementation. If we check the corresponding attributes on the Pyro wrapper:

b = dist.Categorical(loc, validate_args=False)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

the results would be:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([])

which is different from what we get with the Bernoulli distribution:

b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

output:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])

Hi @BigDiaos, I think the issue here is that Bernoulli and Categorical accept parameters of different shape:

  • Bernoulli accepts one parameter per instance (per “coin” flipped), whereas
  • Categorical accepts D-many parameters per instance of choice among D categories.

I think if you want to update the tutorial code, you'll need to add an extra dimension (num_categories,) to the right of your loc parameters, which will require changing the shapes of the encoder and decoder networks.
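
To make the shape difference concrete, here is a small sketch with placeholder tensors (not the tutorial's actual data), showing how the two distributions treat a trailing dimension:

import torch
import pyro.distributions as dist

probs = torch.rand(10, 200, 784)           # [enum, batch, pixels], placeholder values

b = dist.Bernoulli(probs).to_event(1)      # 784 independent coin flips per datum
print(b.batch_shape, b.event_shape)        # torch.Size([10, 200]) torch.Size([784])

c = dist.Categorical(probs)                # trailing 784 dim is consumed as the category dim
print(c.batch_shape, c.event_shape)        # torch.Size([10, 200]) torch.Size([])

probs4 = torch.rand(10, 200, 784, 5)       # add a num_categories=5 dim on the right
c4 = dist.Categorical(probs4).to_event(1)  # now 784 categorical choices per datum
print(c4.batch_shape, c4.event_shape)      # torch.Size([10, 200]) torch.Size([784])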

Thank you fritzo! fehiepsi also helped me figure this out in another topic: https://forum.pyro.ai/t/question-about-batch-size-in-the-semi-supervised-vae-demo/4891/9

Following your instructions, I made the following changes:

loc = self.decoder_x.forward(thetas)
loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))
xs_hat = pyro.sample("x", dist.Categorical(logits=loc, validate_args=False).to_event(1), obs=xs)

in which I

  1. enlarged the output of self.decoder_x to original_dim * num_categories,
  2. reshaped the long vector into a tensor of shape [-1, batch_size, num_instances, num_categories], and
  3. fed that tensor into dist.Categorical (a concrete sketch follows this list).
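
For concreteness, here is roughly how the decoder's last layer and the reshape fit together in my sketch (the sizes and the layer below are placeholders, not my actual architecture):

import torch
import torch.nn as nn
import pyro.distributions as dist

hidden_dim, batch_size, num_instances, num_categories = 50, 200, 784, 5

# placeholder final decoder layer: emit num_instances * num_categories logits per datum
decoder_x = nn.Linear(hidden_dim, num_instances * num_categories)

thetas = torch.randn(10, batch_size, hidden_dim)   # [enum, batch, hidden], placeholder values
loc = decoder_x(thetas)                            # [10, 200, 784 * 5]
loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))

d = dist.Categorical(logits=loc).to_event(1)
print(d.batch_shape, d.event_shape)                # torch.Size([10, 200]) torch.Size([784])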

May I ask for your help in confirming whether this is a proper way to work this out?

Yes, your snippet looks plausible to me. Of course you'll need to update both the encoder and the decoder, but yes, your shapes look ok.

Thanks very much!

Sorry, may I ask an additional question about scale amplification?
Assume my model hierarchy is as follows (the relevant snippets are below).

The snippet generating x was already given above:

loc = self.decoder_x.forward(thetas)
loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))
xs_hat = pyro.sample("x", dist.Categorical(logits=loc, validate_args=False).to_event(1), obs=xs)

Meanwhile, the prior thetas are sampled as follows:

theta_loc, theta_scale = self.decoder_theta.forward([gs, hs])
thetas = pyro.sample("theta", dist.Normal(theta_loc, theta_scale).to_event(1))

So here is my question: should I amplify the scale of theta_loc and theta_scale, and iteratively amplify the scale of their priors?

Can you explain what you mean by “amplify”? Do you mean “share” across instances?

It’s a longer story.

My original model is LDA, a traditional topic model which is inferred with VI. The scale of each variable has a meaning, such as the K topics of \theta.

Then I tried it as a VAE and realized that some of the encoder/decoder networks are too small if I set their scales to be the same as those in the original LDA. Here is the example mentioned above:

def model():
    # gs
    ...
    # hs
    ...
    # theta
    theta_loc, theta_scale = self.decoder_theta.forward([gs, hs])
    thetas = pyro.sample("theta", dist.Normal(theta_loc, theta_scale).to_event(1))
    # x
    x_loc = self.decoder_x.forward(thetas)
    x_loc = torch.reshape(x_loc, (-1, batch_size, num_instances, num_categories))
    xs_hat = pyro.sample("x", dist.Categorical(logits=x_loc, validate_args=False).to_event(1), obs=xs)

Note that the model is hierarchical, and the scales of hs, gs and \theta are all integers smaller than 20. However, the scale of x_loc is so large that it can be more than one thousand (enlarged by original_dim * num_categories).

Besides, to compare with the original implementations of VAE and SS-VAE, the scale of the latent embedding z can be 50 or larger.

So I think the bottleneck (the scales of the variable embeddings) of my implementation may be too narrow for self.decoder_x to reconstruct the data.

:thinking: I think that makes sense if I replace “scale” with “dimension” everywhere, so your theta dimension may be too small to predict the x logits. It seems this might be getting out of the scope of Pyro concerns, but I guess one approach you could take is to avoid high-dimensional categoricals and instead maybe factorize that distribution into a product of lower-dimensional categoricals or Bernoullis? IIUC decoder_x is the LDA embedding?
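
For what it's worth, here is a rough sketch of what such a factorization could look like, under the assumption that each categorical value can be decomposed into two smaller indices (all names and sizes below are hypothetical placeholders, not your actual model):

import torch
import pyro
import pyro.distributions as dist

# Hypothetical factorization: instead of one Categorical over K1*K2 categories per entry,
# use two Categoricals over K1 and K2 categories, assuming each observed value decomposes
# as x = x1 * K2 + x2 (with x1 and x2 conditionally independent given theta).
batch_size, num_instances, theta_dim, K1, K2 = 200, 784, 16, 8, 4
thetas = torch.randn(batch_size, theta_dim)                  # stand-in for sampled thetas
decoder_x1 = torch.nn.Linear(theta_dim, num_instances * K1)  # stand-in decoder heads
decoder_x2 = torch.nn.Linear(theta_dim, num_instances * K2)
xs = torch.randint(0, K1 * K2, (batch_size, num_instances))  # stand-in observations

logits1 = decoder_x1(thetas).reshape(batch_size, num_instances, K1)
logits2 = decoder_x2(thetas).reshape(batch_size, num_instances, K2)
x1 = pyro.sample("x1", dist.Categorical(logits=logits1).to_event(1), obs=xs // K2)
x2 = pyro.sample("x2", dist.Categorical(logits=logits2).to_event(1), obs=xs % K2)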

Yes, you are correct. I mixed up “dimension” with “scale”. Thanks for your advice, I really appreciate it!

I have also come up with a new question, about the losses. I'll explain it here if you don't mind.

It is:
If I want the model to be able to exactly reproduce a new image with the same order of entries (logits) as the original input image, rather than under the “bag-of-words” assumption, do I need to manually construct an auxiliary MSE reconstruction loss to work out which entries it reproduces correctly and which it does not?

I ask because I think the decoder_x network already does the same thing, since num_instances is ordered:

x_loc = torch.reshape(x_loc, (-1, batch_size, num_instances, num_categories))

I also don't know how to set up the MSE loss for x, as with classify_model() in ss_vae; maybe write the same generative process for x in an mse_model() with a dummy mse_guide()?

These things just aren't as simple and intuitive as they are in PyTorch, and I'm confused about some of the details. I hope you don't mind :grinning:

Hi BigDiaos, would you mind starting a new forum post to discuss this separate issue, so that other users may more easily discover our discussion?

Sure, I don't mind at all! See this topic: Is it possible to build reconstruction loss to learn sequential/positional relations among data entries? - Tutorials - Pyro Discussion Forum

Thanks for your help!