Bernoulli does not return batch_size samples

Hi,
I am trying to follow Pyro's SS-VAE example and sample a label 'y' from a Bernoulli distribution whenever a datapoint in my dataset has no label. However, the Bernoulli does not return samples. I first posted this on the Pyro GitHub, thinking it was a bug, but it turned out to be an enumeration issue, so I would like to discuss it here.
In detail:

    def guide(self,xs,ys=None):
        batch_size = xs.size(0)
        assert (batch_size == 20)
        if ys is None:
          with pyro.plate('pore',batch_size):
            assert (batch_size == 20)
            alpha = self.encoder_y.forward(xs) # we get the alpha from encoder_y
            print("alpha shape from the encoder_y", alpha.shape)
            ys = pyro.sample("y", dist.Bernoulli(alpha))
            m = dist.Bernoulli(alpha)
            print("Bernoulli event shape",m.event_shape)
            print("Bernoulli batch shape is",m.batch_shape)
            print("y from the encoder_y Bernoulli", ys.shape)
            print("y from the encoder_y Bernoulli is", ys)

And the output was:

alpha shape from the encoder_y torch.Size([20, 1])
Bernoulli event shape torch.Size([])
Bernoulli batch shape is torch.Size([20, 1])
y from the encoder_y Bernoulli torch.Size([2, 1])
y from the encoder_y Bernoulli is tensor([[0.],
        [1.]])

And my SVI setup was:

    svi = SVI(ssvae.model, config_enumerate(ssvae.guide), optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=False))

I see that the Bernoulli returned the possible values, 0 and 1, rather than drawing samples.

Then I removed config_enumerate from the guide, so the SVI is now:

    svi = SVI(ssvae.model, ssvae.guide, optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=False))

The output now is:
alpha shape from the encoder_y torch.Size([20, 1])
y shape from the Bernoulli torch.Size([20, 20])

However, I do not understand why y.shape is no longer the same as the shape of alpha, the parameter the Bernoulli uses to sample y.
Shouldn't they have the same shape?

I appreciate your input!

If your intention is to have one binary label for each datapoint, then the shape of alpha must correspond to the desired batch shape at site y, which is (20,) because there is one plate of size 20. Your alpha currently has shape (20, 1); the plate at dim -1 broadcasts that batch shape against its size of 20, which is why y comes out as (20, 20).

Please see the Tensor Shapes in Pyro tutorial for more background and tips.
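A minimal sketch of the shape arithmetic behind this answer, using only torch.distributions. It assumes that a pyro.plate of size 20 at dim=-1 effectively broadcasts the site's batch shape against (20,), which matches the (20, 20) output above; torch.broadcast_shapes and Distribution.expand stand in here for what Pyro does internally, and the 0.5 probabilities are placeholders.

```python
import torch
from torch.distributions import Bernoulli

plate_shape = torch.Size([20])  # one plate of size 20 at dim=-1

# alpha with the shape encoder_y produced in the original post: (20, 1)
alpha = torch.full((20, 1), 0.5)
d = Bernoulli(probs=alpha)
print(d.batch_shape)  # torch.Size([20, 1])

# Broadcasting (20, 1) against the plate's (20,) at the rightmost dim
target = torch.broadcast_shapes(d.batch_shape, plate_shape)
print(target)                            # torch.Size([20, 20])
print(d.expand(target).sample().shape)   # torch.Size([20, 20])

# With alpha of shape (20,), the batch shape already matches the plate
d_fixed = Bernoulli(probs=alpha.squeeze(-1))
target_fixed = torch.broadcast_shapes(d_fixed.batch_shape, plate_shape)
print(target_fixed)                # torch.Size([20])
print(d_fixed.sample().shape)      # torch.Size([20])
```

The trailing singleton dimension is what turns one label per datapoint into a 20 x 20 grid of labels.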

Thank you @eb8680_2, I followed the tutorial and changed my alpha to shape [20]. However, now the Bernoulli in the model samples properly (one binary label for each of the 20 datapoints), but the one in the guide does not.
I don't understand why I can see the drawn values (0 or 1) from one Bernoulli but not from the other.
(I followed the enumeration tutorial and used the infer argument of pyro.sample in the guide to marginalize out the discrete variable.)

def model(self,xs,ys=None):
    pyro.module("ss_vae",self)
    with pyro.plate('pore',20):
      batch_size = xs.size(0)
      z_mean = torch.zeros(batch_size, self.z_dim)
      print("z_mean has a shape of", z_mean.shape)
      z_standdev = torch.ones(batch_size, self.z_dim)
      # sample from a gaussian
      zs = pyro.sample("z", dist.Normal(z_mean,z_standdev).to_event(1))
      print("shape of z after being sampled in the model is", zs.shape)
 
      assert (batch_size == 20)
      # uniform prior over {0, 1}; torch.ones(...) / y_dim would give probs = 1.0 when y_dim == 1
      alpha_known = torch.full((batch_size, self.y_dim), 0.5)
      alpha_known = alpha_known.squeeze(1)
      assert (alpha_known.shape == (20,))
      print("alpha known for observed y has the shape of", alpha_known.shape)
      ys = pyro.sample("y", dist.Bernoulli(probs=alpha_known),obs=ys)

      print("supervised y looks like:", ys)
      print("supervised y has a shape of", ys.shape)
      #decode z and y then
      re_input,re_scale = self.decoder.forward(ys,zs)
      pyro.sample("x", dist.Normal(re_input,re_scale).to_event(1), obs=xs)
      return re_input,re_scale


def guide(self,xs,ys=None):
    with pyro.plate('pore', 20):
      if ys is None:
        alpha = self.encoder_y.forward(xs)
        assert (alpha.shape == (20,))
        print("alpha shape from the encoder_y", alpha.shape)
        ys = pyro.sample("y", dist.Bernoulli(alpha),infer={"enumerate": "parallel"})
        d = dist.Bernoulli(alpha)
        print("batch shape",d.batch_shape)
        print("log prob",d.log_prob(ys).shape)
      print("y shape from the Bernoulli", ys.shape)
      print("ys is", ys)

      loc,scale = self.encoder_z.forward(xs,ys)
      pyro.sample("z",dist.Normal(loc,scale).to_event(1))

In the model, the supervised y has shape [20], but in the guide y still has shape [2,1].

supervised y looks like: tensor([1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])
supervised y has a shape of torch.Size([20])
alpha shape from the encoder_y torch.Size([20])
batch shape torch.Size([20])
log prob torch.Size([2, 20])
y shape from the Bernoulli torch.Size([2, 1])
ys is tensor([[0.],
        [1.]])
encoder_z has an x of  torch.Size([20, 28])
encoder_z has an y of  torch.Size([2, 1])

I see that the guide's y has shape [2,1], but it effectively stands for [2,20]: if I instead use ssvae.guide = config_enumerate(ssvae.guide, "parallel", expand=True), y comes out as [2,20] because the enumerated values are expanded.
But these are still not sampled values; they are just the possible outputs.
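For reference, the [2,1], [2,20] and log-prob shapes above match what PyTorch's enumerate_support returns for a Bernoulli with batch shape (20,). A minimal sketch using only torch.distributions, with random stand-in probabilities (an assumption, not the real encoder_y output), showing that the enumerated values are the support {0, 1} rather than draws:

```python
import torch
from torch.distributions import Bernoulli

# Stand-in for encoder_y's output after the fix: one probability per datapoint
alpha = torch.rand(20)
d = Bernoulli(probs=alpha)

# Enumeration works with the support {0, 1}, not with random draws
support = d.enumerate_support(expand=False)
print(support.shape)            # torch.Size([2, 1])
print(support.squeeze(-1))      # tensor([0., 1.])

# expand=True materialises the support for every datapoint
support_full = d.enumerate_support(expand=True)
print(support_full.shape)       # torch.Size([2, 20])

# log_prob broadcasts the (2, 1) support against batch_shape (20,)
print(d.log_prob(support).shape)  # torch.Size([2, 20])

# An ordinary draw, by contrast, has one value per datapoint
draw = d.sample()
print(draw.shape)               # torch.Size([20])
```

Under enumeration the [2,1] tensor is a set of two hypotheses (y = 0 and y = 1) that are each scored for all 20 datapoints, which is why log_prob has shape [2,20].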

Shouldn't they come out with the same shape, since they are drawn from the same distribution?

I use TraceEnum_ELBO in SVI as well:

    svi = SVI(ssvae.model, config_enumerate(ssvae.guide), optimizer, loss=TraceEnum_ELBO(strict_enumeration_warning=True))

I appreciate your input!

I can’t tell from your latest snippet why the shapes are different. What is the code that executes model and guide and causes those messages to be printed? Why are the messages from the model appearing earlier than the messages from the guide in your logged output?

Note that if you’re providing an observed value ys for y when running the model, then the output of sample site y in the model will be that value, and will therefore have shape (20,).

Sorry for the delay, @eb8680_2; I took some time to reread the docs on tensor shapes and enumeration.
I came to the solution below, which I can now train with, but I am not sure my reasoning is right.
I will share my code here so that I can get proper feedback from you.

My goal is to replicate the SS-VAE MNIST example with my own dataset, where my labels are 0, 1, or None (which is what makes the model semi-supervised). I therefore want to use a Bernoulli to sample a label when it is None, which turns this into a discrete enumeration problem.

The only way I can get enumerated values of the right shape out of the Bernoulli is to use to_event(1). My alpha (the parameter of the Bernoulli) has shape [20,1], so when I use to_event(1), the batch_shape becomes [20] and the event_shape becomes [1].
Here is my first confusion:

  1. I use to_event(1) in the Bernoulli sampling because otherwise I cannot feed its output to encoder_z, which takes the Bernoulli sample as input (encoder_z and the decoder need y as an input of shape [20,1]).
    Without to_event(1), the Bernoulli gives [2,1] due to broadcasting, which I can't use as input for encoder_z.
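A minimal sketch of the shapes involved, using torch.distributions.Independent, which is what Pyro's .to_event(1) is built on (an assumption worth checking against the Pyro docs). It covers both the [20,1] case described above and the [20] case that the model and guide code actually use after squeeze(1) / view(-1). It also shows that PyTorch reports has_enumerate_support == False once a batch dimension is reinterpreted as an event dimension. If, as I understand it, config_enumerate only marks sites whose distribution supports enumeration, a to_event(1) site would be sampled rather than enumerated, which would fit the 0/1 values printed further down.

```python
import torch
from torch.distributions import Bernoulli, Independent

# Case 1: alpha of shape (20, 1), as described in the text
alpha_col = torch.full((20, 1), 0.5)
d_col = Independent(Bernoulli(probs=alpha_col), 1)   # like .to_event(1)
print(d_col.batch_shape, d_col.event_shape)   # torch.Size([20]) torch.Size([1])

# Case 2: alpha squeezed to shape (20,), as in the model and guide code
alpha_vec = alpha_col.squeeze(1)
d_vec = Independent(Bernoulli(probs=alpha_vec), 1)
print(d_vec.batch_shape, d_vec.event_shape)   # torch.Size([]) torch.Size([20])

# Reinterpreting batch dims as event dims switches off enumeration support
print(Bernoulli(probs=alpha_vec).has_enumerate_support)  # True
print(d_vec.has_enumerate_support)                        # False

# A single draw is therefore one vector of 20 labels
print(d_vec.sample().shape)  # torch.Size([20])
```

In case 2 the 20 labels form one event, so any plate around the site no longer lines up with them as separate datapoints.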

My model and guide are as follows:

  def model(self,xs,ys=None):
    pyro.module("ss_vae",self)
    with pyro.plate('my_plate',dim=-1):
      batch_size = xs.size(0) #which will be 20

      z_mean = torch.zeros(batch_size, self.z_dim)
      print("z_mean has a shape of", z_mean.shape)
      z_standdev = torch.ones(batch_size, self.z_dim)

      # sample from a gaussian
      zs = pyro.sample("z", dist.Normal(z_mean,z_standdev).to_event(1))
      print("shape of z after being sampled in the model is", zs.shape)
      assert (batch_size == 20)

      alpha_known = torch.full((batch_size,), 0.5)  # same values as 0.5 * np.ones((20,1)) flattened to (20,)
      print("alpha known for observed y has the shape of", alpha_known.shape)
      ys = pyro.sample("y", dist.Bernoulli(probs=alpha_known).to_event(1),obs=ys)
      print("supervised y looks like:", ys)

      #decode z and y then
      print("decoder will use a supervised y with a shape of", ys.shape)
      re_input,re_scale = self.decoder.forward(ys,zs)
      pyro.sample("x", dist.Normal(re_input,re_scale).to_event(1), obs=xs)

      return re_input,re_scale

And the guide:

  def guide(self,xs,ys=None):
    with pyro.plate('my_plate', dim =-1):
      if ys is None:
        alpha = self.encoder_y.forward(xs)
        alpha = alpha.squeeze(1)
        print("alpha shape from the encoder_y", alpha.shape)
        assert (alpha.shape == (20,))
        ys = pyro.sample("y", dist.Bernoulli(alpha).to_event(1))
        print("Bernoulli to_event y shape is", ys.shape)
        print("Bernoulli to_event y values are", ys)
      print("encoder z uses a guide y of shape is", ys.shape)
      loc,scale = self.encoder_z.forward(xs,ys)
      pyro.sample("z",dist.Normal(loc,scale).to_event(1))

The outputs and shapes I receive are:

z_mean has a shape of torch.Size([20, 12])
shape of z after being sampled in the model is torch.Size([20, 12])
alpha known for observed y has the shape of torch.Size([20])
supervised y looks like: tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.,
        1., 1.])
decoder will use a supervised y with a shape of torch.Size([20])
decoder y shape torch.Size([20])
decoder z shape torch.Size([20, 12])

alpha shape from the encoder_y torch.Size([20])
Bernoulli to_event y shape is torch.Size([20])
Bernoulli to_event y values are tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.,
        1., 1.])
encoder z uses a guide y of shape is torch.Size([20])

However, to make this work I also need a shape check in encoder_z:

class Encoder_z(nn.Module):
  def __init__(self,z_dim, hidden_dim,hidden_dim_mid,hidden_dim_bef):
    super().__init__()
    self.encoderzl1 = nn.Linear(input_size+output_size,hidden_dim)
    self.encoderzl2 = nn.Linear(hidden_dim,hidden_dim_mid)
    ...

  def forward(self,x,y):
    # Check dimension of y so this can be used with and without enumeration
    if y.dim() < 2:
      y = y.unsqueeze(1)
    data_combined = torch.cat((x,y), 1)
    ...
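The unsqueeze handles the sampled case, but torch.cat((x, y), 1) would still fail if y arrives as an enumerated tensor such as [2,1], because x is [20,28]. A minimal sketch of one way to make the concatenation enumeration-safe by broadcasting the leading dimensions first; combine_xy is a hypothetical helper, not part of Pyro or of the original encoder.

```python
import torch

def combine_xy(x, y):
    """Hypothetical helper: concatenate x of shape (batch, x_dim) with scalar labels y.

    y may be a plain batch of labels, shape (batch,), or an enumerated tensor
    such as (2, 1) or (2, batch) whose extra left dim comes from enumeration.
    """
    y = y.unsqueeze(-1)  # give y a trailing feature dim of size 1
    # Broadcast the leading (enumeration + batch) dims of x and y together
    lead = torch.broadcast_shapes(x.shape[:-1], y.shape[:-1])
    x = x.expand(lead + x.shape[-1:])
    y = y.expand(lead + y.shape[-1:])
    return torch.cat((x, y), dim=-1)

x = torch.randn(20, 28)  # 20 datapoints with 28 features, as in the printed output
sampled = combine_xy(x, torch.ones(20))
enumerated = combine_xy(x, torch.tensor([[0.0], [1.0]]))
expanded = combine_xy(x, torch.tensor([[0.0], [1.0]]).expand(2, 20))
print(sampled.shape)     # torch.Size([20, 29])
print(enumerated.shape)  # torch.Size([2, 20, 29])
print(expanded.shape)    # torch.Size([2, 20, 29])
```

nn.Linear accepts any number of leading dimensions, so layers after the concatenation can usually stay unchanged.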

I use config_enumerate in my guide:

    ssvae = SSVAE()
    infer={"enumerate": "parallel"}
    ssvae.guide = config_enumerate(ssvae.guide,"parallel", expand=False) 
    pyro.clear_param_store()
    svi = SVI(ssvae.model, ssvae.guide, optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1,strict_enumeration_warning=True))

  2. In this case I can do enumeration with to_event(1), but are the Bernoullis in my model and guide still sampled independently (since they are under pyro.plate)? I need each Bernoulli draw to be independent of the others (a new alpha means a new Bernoulli choice).

  3. Or do I need to write pyro.plate('my_plate', 20) to declare independence across the 20 values? If I use pyro.plate('my_plate', 20) and remove to_event(1) from the Bernoullis in both the model and the guide, the Bernoulli starts giving [2,1].
    If I keep both pyro.plate('my_plate', 20) and to_event(1) in the model and guide, the Bernoulli's shape is [20,20], which is not what I want.

I would appreciate hearing back from you when you have time. I am very confused about the dependencies, and about why to_event(1) gives a vector while the plate seems to have no effect. I am quite stuck here.

Thank you and I apologize for the super long question!