Hi again!
I would just like to know what is the issue with my approach when masking some elements in the sequence from the marginal likelihood calculation (“x”). I also, depending on the learning set-up (supervised, unsupervised and semi supervised) want to mask some data point’s target (“c”) also from the likelihood calculation. I provide a runable example:
import torch
from torch import tensor
import pyro
from pyro import sample,plate
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI,TraceEnum_ELBO
from pyro.optim import ClippedAdam
def model(x,obs_mask,x_class,class_mask):
"""
:param x: Data [N,L,feat_dim]
:param obs_mask: Data sites to mask [N,L]
:param x_class: Target values [N,]
:param class_mask: Target values mask [N,]
"""
with plate("inner", dim=-1):
z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
#Highlight: Target
if learning_type == "unsupervised":
c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1))
elif learning_type == "semisupervised":
c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),obs=x_class,obs_mask=class_mask)
else:
c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1),obs=x_class)
#Highlight: Sequence reconstruction
with plate("outer",dim=-2):
logits = torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
[[1,2,7],[0,2,1],[2,7,8]]])
aa = sample("x",dist.Categorical(logits= logits),obs=x,obs_mask=obs_mask)
return z,c,aa
def guide(x,obs_mask,x_class,class_mask):
"""
:param x: Data [N,L,feat_dim]
:param obs_mask: Data sites to mask [N,L]
:param x_class: Target values [N,]
:param class_mask: Target values mask [N,]
"""
with plate("inner", dim=-1):
z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
if learning_type == "unsupervised":
c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),infer={'enumerate': 'parallel'})
elif learning_type == "semisupervised":
c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1).mask(class_mask),infer={'enumerate': 'parallel'})
else: #supervised
c = 0
#Highlight: Sequence reconstruction: When using obs_mask in the model it keeps complaining about unobserved sites. That is why added this segment here
with plate("outer",dim=-2):
logits = torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
[[1,2,7],[0,2,1],[2,7,8]]])
aa = sample("x",dist.Categorical(logits= logits).mask(~obs_mask),infer={'enumerate': 'parallel'}) #Still not sure if this is correct
return z,c,aa
if __name__ == "__main__":
learning_ops = {0:"supervised",
1:"unsupervised",
2:"semisupervised"}
learning_type = learning_ops[0]
x = tensor([[0,2,1],
[0,1,1]])
obs_mask = tensor([[1,0,0],[1,1,0]],dtype=bool) #I need a mask like this to work over the len dimension
x_class = tensor([0,1])
class_mask = tensor([1,0],dtype=bool) #Also this one, over the batch dimension
guide_tr = poutine.trace(guide).get_trace(x,obs_mask,x_class,class_mask)
model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(x,obs_mask,x_class,class_mask)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
print(monte_carlo_elbo)
svi = SVI(model,guide,loss=TraceEnum_ELBO(),optim=ClippedAdam(dict()))
svi.step(x,obs_mask,x_class,class_mask)
To start with, in the supervised approach, it pops a warning, which becomes an error with my actual model:
/home/.../miniconda3/lib/python3.8/site-packages/pyro/util.py:288: UserWarning: Found non-auxiliary vars in guide but not model, consider marking these infer={'is_auxiliary': True}:
{'x'}
warnings.warn(
/home/.../miniconda3/lib/python3.8/site-packages/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
warnings.warn(f"Found vars in model but not guide: {bad_sites}")
Feel free to split the models and the guides in 3 different ones according to learning types (I just though like this was more condensed) .Thanks in advance!