Sorry @eb8680_2, I took some time to re-read the docs on tensor shapes and enumeration, and I came to the solution below, which I can now train with. However, I am not sure my reasoning is right, so I will share my code here so I can get your feedback properly.
My goal is to replicate the ‘SSVAE’ MNIST example with my own dataset, where my labels are 0, 1, or None (which makes it a semi-supervised model). I therefore want to sample the label from a Bernoulli when it is None, and this makes it a discrete enumeration problem.
The only way I can reach the enumerated values of the Bernoulli is when I use `to_event(1)`. My alpha (the parameter of the Bernoulli) has shape [20, 1], so when I use `to_event(1)`, the `batch_shape` becomes 20 and the `event_shape` becomes 1.
Here comes my first confusion: I use `to_event(1)` in the Bernoulli sampling, because otherwise I cannot feed its output to `encoder_z`, which needs the output of the Bernoulli sampling (`Encoder_z` and the decoder need a [20, 1] input for y). Without `to_event(1)`, the Bernoulli gives [2, 1] due to broadcasting, which I can't use as input for `encoder_z`.
My model and guide are as follows:
```python
def model(self, xs, ys=None):
    pyro.module("ss_vae", self)
    with pyro.plate('my_plate', dim=-1):
        batch_size = xs.size(0)  # which will be 20
        z_mean = torch.zeros(batch_size, self.z_dim)
        print("z_mean has a shape of", z_mean.shape)
        z_standdev = torch.ones(batch_size, self.z_dim)
        # sample from a gaussian
        zs = pyro.sample("z", dist.Normal(z_mean, z_standdev).to_event(1))
        print("shape of z after being sampled in the model is", zs.shape)
        assert batch_size == 20
        alpha_known = torch.Tensor(0.5 * np.ones((20, 1))).view(-1)
        print("alpha known for observed y has the shape of", alpha_known.shape)
        ys = pyro.sample("y", dist.Bernoulli(probs=alpha_known).to_event(1), obs=ys)
        print("supervised y looks like:", ys)
        # decode z and y then
        print("decoder will use a supervised y with a shape of", ys.shape)
        re_input, re_scale = self.decoder.forward(ys, zs)
        pyro.sample("x", dist.Normal(re_input, re_scale).to_event(1), obs=xs)
        return re_input, re_scale
```
And the guide:
```python
def guide(self, xs, ys=None):
    with pyro.plate('my_plate', dim=-1):
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            alpha = alpha.squeeze(1)
            print("alpha shape from the encoder_y", alpha.shape)
            assert alpha.shape == (20,)
            ys = pyro.sample("y", dist.Bernoulli(alpha).to_event(1))
            print("Bernoulli to_event y shape is", ys.shape)
            print("Bernoulli to_event y values are", ys)
        print("encoder z uses a guide y of shape is", ys.shape)
        loc, scale = self.encoder_z.forward(xs, ys)
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))
```
The outputs and shapes I receive are:
```
z_mean has a shape of torch.Size([20, 12])
shape of z after being sampled in the model is torch.Size([20, 12])
alpha known for observed y has the shape of torch.Size([20])
supervised y looks like: tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.,
        1., 1.])
decoder will use a supervised y with a shape of torch.Size([20])
decoder y shape torch.Size([20])
decoder z shape torch.Size([20, 12])
alpha shape from the encoder_y torch.Size([20])
Bernoulli to_event y shape is torch.Size([20])
Bernoulli to_event y values are tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.,
        1., 1.])
encoder z uses a guide y of shape is torch.Size([20])
```
However, to make this work I also need a shape check in `encoder_z`:
```python
class Encoder_z(nn.Module):
    def __init__(self, z_dim, hidden_dim, hidden_dim_mid, hidden_dim_bef):
        super().__init__()
        self.encoderzl1 = nn.Linear(input_size + output_size, hidden_dim)
        self.encoderzl2 = nn.Linear(hidden_dim, hidden_dim_mid)
        ...

    def forward(self, x, y):
        # Check dimension of y so this can be used with and without enumeration
        if y.dim() < 2:
            y = y.unsqueeze(1)
        data_combined = torch.cat((x, y), 1)
        ...
```
I use config_enumerate in my guide:
```python
ssvae = SSVAE()
infer = {"enumerate": "parallel"}
ssvae.guide = config_enumerate(ssvae.guide, "parallel", expand=False)
pyro.clear_param_store()
svi = SVI(ssvae.model, ssvae.guide, optimizer,
          loss=TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=True))
```

Then, in this case, I can manage to do enumeration with `to_event(1)`, but are the Bernoulli samples in my model and guide still independent of one another, given that they are under `pyro.plate`? I need each Bernoulli sample to be independent of the others (a new alpha = a new Bernoulli choice).

Or do I need to write `pyro.plate('my_plate', 20)` to declare independence over the 20 values? If I use `pyro.plate('my_plate', 20)` and remove `to_event(1)` from the Bernoullis in both the model and the guide, then the Bernoulli starts to give [2, 1] again. If I keep both `pyro.plate('my_plate', 20)` and `to_event(1)` in the model and guide, then the Bernoulli's shape is [20, 20], which is not what I want.
I would appreciate hearing back from you when you have time, as I am very confused about the dependencies, and about why `to_event(1)` gives a vector while `plate` seems to have no effect… I am quite stuck here…
Thank you and I apologize for the super long question!