Hi, fehiepsi, thanks very much!
I read the enumeration tutorial and realized that the additional dimension in [10, 200, 10] is there to enumerate the 10 classes of handwritten digits for the discrete variable ys=[200, 10] (batch_size=200).
But I've got another question; may I ask if you could help me figure it out? Here is my code:
I first define a logistic-normal encoder to replace the Dirichlet prior:
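Just to check my understanding: enumerating the support of such a y distribution directly gives exactly this shape (a minimal check with made-up uniform probabilities):

import torch
import pyro.distributions as dist

alpha = torch.ones(200, 10) / 10           # batch_size=200, 10 classes
y_dist = dist.OneHotCategorical(alpha)
print(y_dist.enumerate_support().shape)    # torch.Size([10, 200, 10])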
import torch
import torch.nn as nn
import torch.nn.functional as F

class LN_Encoder(nn.Module):
    # Base class for the encoder net, used in the guide.
    # Uses a logistic normal to replace the Dirichlet.
    def __init__(self, input_size, var_size, hidden, dropout):
        super().__init__()
        assert type(hidden) == int, "We have only one hidden layer in LN_Encoder, so the layer size must be an int"
        self.drop = nn.Dropout(dropout)  # to avoid component collapse
        self.fc1 = nn.Linear(input_size, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fcmu = nn.Linear(hidden, var_size)
        self.fclv = nn.Linear(hidden, var_size)
        # NB: here we set `affine=False` to reduce the number of learnable parameters.
        # See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
        # for the effect of this flag in BatchNorm1d.
        self.bnmu = nn.BatchNorm1d(var_size, affine=False)  # to avoid component collapse
        self.bnlv = nn.BatchNorm1d(var_size, affine=False)  # to avoid component collapse

    def forward(self, inputs):
        h = F.softplus(self.fc1(inputs))
        h = F.softplus(self.fc2(h))
        h = self.drop(h)
        # μ and Σ are the outputs
        logtheta_loc = self.bnmu(self.fcmu(h))
        logtheta_logvar = self.bnlv(self.fclv(h))
        logtheta_scale = (0.5 * logtheta_logvar).exp()  # enforces positivity
        return logtheta_loc, logtheta_scale
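A quick sanity check of the encoder output shapes (the sizes here are placeholders, just for illustration):

enc = LN_Encoder(input_size=794, var_size=8, hidden=100, dropout=0.2)
loc, scale = enc(torch.randn(200, 794))
print(loc.shape, scale.shape)  # torch.Size([200, 8]) torch.Size([200, 8])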
Then I follow the original ss_vae code to define my model:
import pyro
import pyro.distributions as dist
from utils.custom_mlp import MLP, Exp, ConcatModule

class myVAE(nn.Module):
    def setup_networks(self):
        self.decoder_x = MLP(
            [self.theta_dim] + hidden_sizes + [self.args.num_voxels_per_robo],
            activation=nn.Softplus,
            output_activation=nn.Sigmoid,
            allow_broadcast=self.allow_broadcast,
            use_cuda=self.use_cuda,
        )
        self.decoder_theta = MLP(...)
        self.encoder_theta = MLP(...)
        self.encoder_y = MLP(...)
        self.encoder_g = LN_Encoder(...)

    def model(self, xs, ys=None, tou=torch.tensor(1.0)):
        pyro.module("my_vae", self)
        bs = xs.size(0)  # batch size
        options = dict(dtype=xs.dtype, device=xs.device)
        with pyro.plate("data"):
            # ys
            alpha_prior = torch.ones(bs, self.h_dim, **options) / (1.0 * self.h_dim)
            ys = pyro.sample("y", dist.RelaxedOneHotCategorical(probs=alpha_prior, temperature=tou, validate_args=False), obs=ys)
            # g
            log_g_loc = torch.zeros(bs, self.g_dim, **options)
            log_g_scale = torch.ones(bs, self.g_dim, **options)
            log_g = pyro.sample("log_g", dist.Normal(log_g_loc, log_g_scale).to_event(1))
            gs = F.softmax(log_g, -1)
            # theta
            theta_loc, theta_scale = self.decoder_theta.forward([gs, ys])  # refer to ladder or hierarchical VAE
            thetas = pyro.sample("theta", dist.Normal(theta_loc, theta_scale).to_event(1))
            # sample x
            loc = self.decoder_x.forward(thetas)
            xs_hat = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)

    def guide(self, xs, ys=None, tou=torch.tensor(1.0)):
        with pyro.plate("data"):
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.RelaxedOneHotCategorical(probs=alpha, temperature=tou, validate_args=False))
            log_g_loc, log_g_scale = self.encoder_g(torch.cat([xs, ys], dim=-1))
            log_g = pyro.sample("log_g", dist.Normal(log_g_loc, log_g_scale).to_event(1))
            gs = F.softmax(log_g, -1)
            theta_loc, theta_scale = self.encoder_theta.forward([xs, gs])
            thetas = pyro.sample("theta", dist.Normal(theta_loc, theta_scale).to_event(1))
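For completeness, my training setup follows the ss_vae example; roughly like this (a simplified sketch, the actual script builds separate supervised and unsupervised losses and uses different hyperparameters):

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

my_vae = myVAE(...)                       # constructor arguments omitted
optimizer = Adam({"lr": 1e-3})            # placeholder learning rate
svi = SVI(my_vae.model, my_vae.guide, optimizer, loss=TraceEnum_ELBO(max_plate_nesting=1))
# inside run_inference_for_epoch:
# new_loss = svi.step(xs)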
Then I’ve encountered two problems:
- Compared with the original ss_vae, there is no auto enumeration. I take this to mean there is no discrete distribution to enumerate, except for sampling x:
xs_hat = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
- The event shape for dist.Categorical is unexpected:
loc = self.decoder_x.forward(thetas)
print(f"loc shape:{loc.shape}")
a = dist.Categorical(loc, validate_args=False).to_event(1)
print(f"a batch_shape:{a.batch_shape}")
print(f"a event_shape:{a.event_shape}")
I used “.to_event(1)” to ensure the event dimension, but got:
loc shape:torch.Size([200, 784])
a batch_shape:torch.Size([])
a event_shape:torch.Size([200])
So the dimension of the sampled xs_hat is [200], not [784], and when I change it to ".to_event(0)", the result is:
loc shape:torch.Size([200, 784])
a batch_shape:torch.Size([200])
a event_shape:torch.Size([])
There are no event dimensions, but the batch_shape is right. And the size of xs_hat is also right:
xs_hat = pyro.sample("x", dist.Categorical(loc, validate_args=False).to_event(1), obs=xs)
print(f"xs_hat shape:{xs_hat.shape}")
and we have:
xs_hat shape:torch.Size([200, 784])
However, still in the model(), the results for
b = dist.Normal(theta_loc, theta_scale).to_event(1)
print(f"b batch_shape:{b.batch_shape}")
print(f"b event_shape:{b.event_shape}")
are right, and I got:
b batch_shape:torch.Size([200])
b event_shape:torch.Size([8])
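To isolate the behaviour, here is a standalone comparison with made-up tensors of the same shapes; it looks like Categorical consumes the last dimension (784) as the category dimension, while Normal keeps it as a batch dimension:

loc = torch.rand(200, 784)                    # same shape as the decoder_x output, values made up
a = dist.Categorical(loc, validate_args=False)
print(a.batch_shape, a.event_shape)           # torch.Size([200]) torch.Size([])

theta_loc, theta_scale = torch.zeros(200, 8), torch.ones(200, 8)
b = dist.Normal(theta_loc, theta_scale).to_event(1)
print(b.batch_shape, b.event_shape)           # torch.Size([200]) torch.Size([8])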
I think this weird problem may be the cause of the mismatch error when broadcasting tensors:
Traceback (most recent call last):
  File "D:\Programming\VAE\myLDA\train_myvae_v0.2.py", line 363, in <module>
    main(EXAMPLE_RUN)
  File "D:\Programming\VAE\myLDA\train_myvae_v0.2.py", line 184, in main
    epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(
  File "D:\Programming\VAE\myLDA\train_myvae_v0.2.py", line 70, in run_inference_for_epoch
    new_loss = losses[loss_id].step(xs)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\traceenum_elbo.py", line 451, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\traceenum_elbo.py", line 394, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\traceenum_elbo.py", line 339, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\infer\enum.py", line 75, in get_importance_trace
    model_trace.compute_log_prob()
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\poutine\trace_struct.py", line 230, in compute_log_prob
    log_p = site["fn"].log_prob(
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\torch\distributions\independent.py", line 99, in log_prob
    log_prob = self.base_dist.log_prob(value)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\pyro\distributions\torch.py", line 141, in log_prob
    return super().log_prob(value)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\torch\distributions\categorical.py", line 125, in log_prob
    value, log_pmf = torch.broadcast_tensors(value, self.logits)
  File "C:\Users\bigya\anaconda3\envs\pyroEnv\lib\site-packages\torch\functional.py", line 74, in broadcast_tensors
    return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
RuntimeError: The size of tensor a (784) must match the size of tensor b (200) at non-singleton dimension 1
Can you please help me figure this out? I've spent multiple days debugging these errors.
Thanks in advance!