Question about batch_size in the semi-supervised VAE demo

Hi, all!
I’m running the semi-supervised VAE demo (“The Semi-Supervised VAE” in the Pyro Tutorials 1.8.4 documentation). The default is batch_size=200 and each MNIST image has 28*28=784 pixels, so the input matrix for each training batch should have xs.shape=[200, 784].

However, when I print the shapes of the tensors inside model() and guide(), the shape of xs is sometimes [10, 200, 784] and sometimes [200, 784].

Could someone point me to a tutorial that explains this? Thanks a lot!

The additional dimension 10 is the enumerated dimension - please see the Interlude: Summing Out Discrete Latents section. You can take a look at the enumeration tutorial as a good start.

Hi, fehiepsi, thanks very much!
I read the enumeration tutorial and realized that the additional leading dimension in [10, 200, 10] enumerates the 10 handwritten-digit classes for the discrete variable ys, which has shape [200, 10] (batch_size=200).
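
A minimal sketch of where the leading 10 comes from, using only torch.distributions (the class that Pyro's dist.OneHotCategorical wraps). The uniform probs and the manual expand of xs are assumptions made for illustration; in Pyro the enumeration dimension is allocated by the inference machinery, and the tutorial's networks handle the broadcasting.

```python
import torch
from torch.distributions import OneHotCategorical

batch_size, num_classes, num_pixels = 200, 10, 784

# Uniform prior over the 10 digit classes, one row per image in the batch.
probs = torch.full((batch_size, num_classes), 1.0 / num_classes)
y_dist = OneHotCategorical(probs=probs)

# Enumerating the support yields every one-hot label for every batch element,
# which adds a new leading dimension of size 10 in front of the batch dimension.
ys_enum = y_dist.enumerate_support(expand=True)
print(ys_enum.shape)  # torch.Size([10, 200, 10])

# Broadcasting a [200, 784] batch against the enumerated labels gives the
# [10, 200, 784] shape seen inside model() and guide().
xs = torch.zeros(batch_size, num_pixels)
xs_enum = xs.expand(ys_enum.shape[:-1] + xs.shape[-1:])
print(xs_enum.shape)  # torch.Size([10, 200, 784])
```

This would also be consistent with seeing plain [200, 784] in some steps: when the labels ys are observed there is nothing to enumerate, so presumably no extra dimension is added for those batches.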

But I’ve got another question; may I ask for your help with it? Here is my code:

First I define a logistic-normal encoder to replace the Dirichlet prior:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LN_Encoder(nn.Module):
	# Base class for the encoder net, used in the guide
	# Use a logistic-normal distribution to replace the Dirichlet
	def __init__(self, input_size, var_size, hidden, dropout):
		super().__init__()
		assert isinstance(hidden, int), "LN_Encoder uses a single hidden layer size, so `hidden` must be an int"
		self.drop = nn.Dropout(dropout)  # to avoid component collapse
		self.fc1 = nn.Linear(input_size, hidden)
		self.fc2 = nn.Linear(hidden, hidden)
		self.fcmu = nn.Linear(hidden, var_size)
		self.fclv = nn.Linear(hidden, var_size)
		# NB: here we set `affine=False` to reduce the number of learning parameters
		# See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
		# for the effect of this flag in BatchNorm1d
		self.bnmu = nn.BatchNorm1d(var_size, affine=False)  # to avoid component collapse
		self.bnlv = nn.BatchNorm1d(var_size, affine=False)  # to avoid component collapse

	def forward(self, inputs):
		h = F.softplus(self.fc1(inputs))
		h = F.softplus(self.fc2(h))
		h = self.drop(h)
		# μ and Σ are the outputs
		logtheta_loc = self.bnmu(self.fcmu(h))
		logtheta_logvar = self.bnlv(self.fclv(h))
		logtheta_scale = (0.5 * logtheta_logvar).exp()  # Enforces positivity
		return logtheta_loc, logtheta_scale

Then I follow the original ss_vae code to define my model:

import pyro
import pyro.distributions as dist
from utils.custom_mlp import MLP, Exp, ConcatModule
class myVAE(nn.Module):
    def setup_networks(self):
        self.decoder_x = MLP(
			[self.theta_dim] + hidden_sizes + [self.args.num_voxels_per_robo],
			activation=nn.Softplus,
			output_activation=nn.Sigmoid,
			allow_broadcast=self.allow_broadcast,
			use_cuda=self.use_cuda,
		)
        self.decoder_theta = MLP(...)
        self.encoder_theta = MLP(...)
        self.encoder_y = MLP(...)
        self.encoder_g = LN_Encoder(...)
    def model(self, xs, ys=None, tou=torch.tensor(1.0)):
        pyro.module("my_vae", self)
        bs = xs.size(0)	# batch size
        options = dict(dtype=xs.dtype, device=xs.device)
        with pyro.plate("data"):
            # ys
            alpha_prior = torch.ones(bs, self.h_dim, **options) / (1.0 * self.h_dim)
            ys = pyro.sample("y", dist.RelaxedOneHotCategorical(probs=alpha_prior, temperature=tou, validate_args=False), obs=ys)
            # g
            log_g_loc = torch.zeros(bs, self.g_dim, **options)
            log_g_scale = torch.ones(bs, self.g_dim, **options)
            log_g = pyro.sample("log_g", dist.Normal(log_g_loc, log_g_scale).to_event(1))
            gs = F.softmax(log_g, -1)
            # theta
            theta_loc, theta_scale = self.decoder_theta.forward([gs, ys])  # refer to ladder or hierarchical vae
            thetas = pyro.sample("theta", dist.Normal(theta_loc, theta_scale).to_event(1))
            # sample x
            loc = self.decoder_x.forward(thetas)
            xs_hat = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
    def guide(self, xs, ys=None, tou=torch.tensor(1.0)):
        with pyro.plate("data"):
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.RelaxedOneHotCategorical(probs=alpha, temperature=tou, validate_args=False))
            log_g_loc, log_g_scale = self.encoder_g(torch.cat([xs, ys], dim=-1))
            log_g = pyro.sample("log_g", dist.Normal(log_g_loc, log_g_scale).to_event(1))
            gs = F.softmax(log_g, -1)
            theta_loc, theta_scale = self.encoder_theta.forward([xs, gs])
            thetas = pyro.sample("theta", dist.Normal(theta_loc, theta_scale).to_event(1))

I then ran into two problems:

  1. Compared with the original ss_vae, there is no automatic enumeration. I assume this is because the model has no discrete distribution other than the one used to sample x:
xs_hat = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
  2. The event shape for dist.Categorical is unexpected:
loc = self.decoder_x.forward(thetas)
print(f"loc shape:{loc.shape}")
a = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"a batch_shape:{a.batch_shape}")
print(f"a event_shape:{a.event_shape}")

I used “.to_event(1)” to declare one event dimension, but got:

loc shape:torch.Size([200, 784])
a batch_shape:torch.Size([])
a event_shape:torch.Size([200])

So the event shape of the sampled xs_hat is [200], not [784]. When I change it to “to_event(0)”, the result is:

loc shape:torch.Size([200, 784])
a batch_shape:torch.Size([200])
a event_shape:torch.Size([])

There are no event dimensions, but the batch_shape is right. The size of xs_hat is also right:

xs_hat = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
print(f"xs_hat shape:{xs_hat.shape}")

and we have:

xs_hat shape:torch.Size([200, 784])

However, still in the model(), the results for

b = dist.Normal(theta_loc, theta_scale).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

are right, and I got:

b batch_shape:torch.Size([200])
b event_shape:torch.Size([8])

I think this weird problem may be the cause of the shape-mismatch error when broadcasting tensors:

Traceback (most recent call last):
  File "D:\Programming\VAE\myLDA\train_myvae_v0.2.py", line 363, in <module>
    main(EXAMPLE_RUN)
  File "D:\Programming\VAE\myLDA\train_myvae_v0.2.py", line 184, in main
    epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(
  File "D:\Programming\VAE\myLDA\train_myvae_v0.2.py", line 70, in run_inference_for_epoch
    new_loss = losses[loss_id].step(xs)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\traceenum_elbo.py", line 451, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\traceenum_elbo.py", line 394, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\traceenum_elbo.py", line 339, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\enum.py", line 75, in get_importance_trace
    model_trace.compute_log_prob()
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\poutine\trace_struct.py", line 230, in compute_log_prob
    log_p = site["fn"].log_prob(
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\torch\distributions\independent.py", line 99, in log_prob
    log_prob = self.base_dist.log_prob(value)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\distributions\torch.py", line 141, in log_prob
    return super().log_prob(value)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\torch\distributions\categorical.py", line 125, in log_prob
    value, log_pmf = torch.broadcast_tensors(value, self.logits)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\torch\functional.py", line 74, in broadcast_tensors
    return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
RuntimeError: The size of tensor a (784) must match the size of tensor b (200) at non-singleton dimension 1

Can you please help me figure this out? I have spent several days debugging these errors.
Thanks in advance!

Because the probs tensor loc has shape (200, 784), Categorical reads the last dimension as 784 categories, so its batch shape is (200,) and its event shape is (). If we are given a fair die, its probs will have shape (6,). Each time you throw it, you get one number (not 6 numbers). I think your decoder needs to return something with shape (200, 784, dim), where dim is the total number of possible values of xs, e.g. dim = xs.max() + 1. Btw, an MLP typically returns a real-valued tensor; if so, it is better to pass loc as logits rather than probs.
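
The shape rule above can be checked with a minimal sketch in plain torch.distributions. Independent(d, 1) stands in for Pyro's .to_event(1) (the traceback shows the call going through torch's independent.py), and num_values = 6 is a hypothetical number of values per pixel, not something taken from the original model.

```python
import math
import torch
from torch.distributions import Categorical, Independent

batch_size, num_pixels = 200, 784
num_values = 6  # hypothetical number of values a single pixel can take

# A (200, 784) parameter: the last dimension is read as 784 categories.
d_flat = Categorical(logits=torch.zeros(batch_size, num_pixels))
print(d_flat.batch_shape, d_flat.event_shape)  # torch.Size([200]) torch.Size([])

# Reinterpreting one batch dimension as an event dimension consumes the 200.
d_flat_event = Independent(d_flat, 1)
print(d_flat_event.batch_shape, d_flat_event.event_shape)  # torch.Size([]) torch.Size([200])

# A (200, 784, 6) parameter: one 6-way distribution per pixel.
d_pix = Independent(
    Categorical(logits=torch.zeros(batch_size, num_pixels, num_values)), 1
)
print(d_pix.batch_shape, d_pix.event_shape)  # torch.Size([200]) torch.Size([784])

# Integer-valued observations of shape (200, 784) now score without error.
xs = torch.randint(0, num_values, (batch_size, num_pixels))
log_p = d_pix.log_prob(xs)
print(log_p.shape)  # torch.Size([200])
```

The second print reproduces the batch_shape [] / event_shape [200] pair reported in the question, and the last distribution has the per-image layout that obs=xs expects.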

Hi all.
I replaced dist.Categorical for xs in the last line of model() with dist.RelaxedOneHotCategorical, and the problem was solved.
So I think the error occurred because dist.Categorical automatically enumerates and treats the 200 in the input matrix [200, 784] as an enumerated dimension, while the other distributions in the prior are not discrete.

Oh, if your output is one-hot, you need to use a one-hot likelihood such as dist.OneHotCategorical. Enumeration is not related to the issue.
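
A minimal sketch of why the swap changed the shapes, again with plain torch.distributions: the one-hot variants treat the last parameter dimension as the event, whereas Categorical drops it. The uniform loc below is a stand-in for the decoder output and is an assumption for illustration only.

```python
import torch
from torch.distributions import OneHotCategorical, RelaxedOneHotCategorical

# Stand-in for a normalized decoder output of shape (200, 784).
loc = torch.full((200, 784), 1.0 / 784)

one_hot = OneHotCategorical(probs=loc)
print(one_hot.batch_shape, one_hot.event_shape)  # torch.Size([200]) torch.Size([784])

relaxed = RelaxedOneHotCategorical(temperature=torch.tensor(1.0), probs=loc)
print(relaxed.batch_shape, relaxed.event_shape)  # torch.Size([200]) torch.Size([784])

# Each sample is a length-784 vector with exactly one entry equal to 1.
x = one_hot.sample()
print(x.shape)  # torch.Size([200, 784])
```

This only shows the shapes. Whether a one-hot likelihood is the right model depends on whether each row of xs really is one-hot.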

Really appreciate it, thanks bro.
I’m not yet familiar with the tool and am sometimes confused by how some of the distributions are implemented, the elegant mechanisms, the tricks, and so on. But I’m really thankful for your work.
Hopefully I’ll catch up.

Hi, fehiepsi, can you please help me with this issue: “dist.Bernoulli recognize the right batch_size, but dist.Categorical won’t” (Tutorials, Pyro Discussion Forum)?

The same code with a Bernoulli distribution recognizes the right batch_shape and event_shape:

x = pyro.sample("x", dist.Bernoulli(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Bernoulli(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

would output:

b batch_shape:torch.Size([10, 200])
b event_shape:torch.Size([784])

but

x = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
b = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")

output:

b batch_shape:torch.Size([10])
b event_shape:torch.Size([200])

I dug a little further into the torch.distributions.Categorical implementation, and maybe you could help me out in that topic.

See my comment above: Question about batch_size in the semi-supervised VAE demo - #4 by fehiepsi

Hi, fehiepsi
I understand now: dist.Categorical takes a probability vector whose last dimension is the number of possible categories of the discrete variable (for a die, that dimension is 6). So if each pixel has 6 possible categories, the decoder needs to return a tensor of shape [200, 784, 6].

But I had been confused by the ss_vae code, where dist.Bernoulli does not need an additional dimension like [prob_0, prob_1], until I saw from the example in torch.distributions.bernoulli.Bernoulli that Bernoulli only takes the probability of the positive label, so an additional dimension is not required.
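
To make the Bernoulli point concrete, here is a small sketch comparing a Bernoulli over binary pixels with the equivalent two-category Categorical. The random p and x are made up for illustration; only the parameter shapes matter.

```python
import torch
from torch.distributions import Bernoulli, Categorical

# P(pixel = 1) for each pixel, kept away from 0 and 1 for numerical safety.
p = torch.rand(200, 784) * 0.98 + 0.01
x = torch.bernoulli(p)  # binary pixels, shape (200, 784)

# Bernoulli needs one number per pixel: the probability of the positive label.
bern = Bernoulli(probs=p)

# Categorical needs the full vector [prob_0, prob_1] in an extra last dimension.
cat = Categorical(probs=torch.stack([1 - p, p], dim=-1))

print(bern.batch_shape, cat.batch_shape)  # torch.Size([200, 784]) torch.Size([200, 784])

# Both assign the same log-probability to the same pixels.
same = torch.allclose(bern.log_prob(x), cat.log_prob(x.long()), atol=1e-4)
print(same)
```

So a (200, 784) parameter is one distribution per pixel for Bernoulli, but needs shape (200, 784, 2) to mean the same thing for Categorical.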

To wrap up my questions and leave a hint for other newcomers: I followed fehiepsi’s suggestion, had the decoder output a vector of dimension 784 * 6, and reshaped it to [784, 6] (or, with the leading dimensions, [enumeration_size, batch_size, 784, 6]).
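
A minimal sketch of that reshape, assuming 6 values per pixel as in the example above. pixel_likelihood is a hypothetical helper name, the random tensors stand in for the decoder output, and Independent(..., 1) plays the role of .to_event(1); in Pyro the same idea would be dist.Categorical(logits=logits).to_event(1).

```python
import torch
from torch.distributions import Categorical, Independent

num_pixels, num_values = 784, 6  # 6 values per pixel is an assumed example

def pixel_likelihood(flat_logits):
    """Reshape (..., 784 * 6) decoder output to (..., 784, 6) and wrap it.

    Hypothetical helper: leading dimensions (batch, enumeration) are kept as-is.
    """
    logits = flat_logits.reshape(*flat_logits.shape[:-1], num_pixels, num_values)
    return Independent(Categorical(logits=logits), 1)

# Plain batch: decoder output of shape (200, 784 * 6).
d = pixel_likelihood(torch.randn(200, num_pixels * num_values))
print(d.batch_shape, d.event_shape)  # torch.Size([200]) torch.Size([784])

# With a leading enumeration dimension: (10, 200, 784 * 6).
d_enum = pixel_likelihood(torch.randn(10, 200, num_pixels * num_values))
print(d_enum.batch_shape, d_enum.event_shape)  # torch.Size([10, 200]) torch.Size([784])
```

Reshaping only the last dimension keeps the helper agnostic to whether an enumeration dimension is present.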