More doubts on masking - Runnable example

Hi again! :slight_smile:

I would just like to know what is the issue with my approach when masking some elements in the sequence from the marginal likelihood calculation (“x”). I also, depending on the learning set-up (supervised, unsupervised and semi supervised) want to mask some data point’s target (“c”) also from the likelihood calculation. I provide a runable example:

import torch
from torch import tensor
import pyro
from pyro import sample,plate
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI,TraceEnum_ELBO
from pyro.optim import ClippedAdam
def model(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    """
    with plate("inner", dim=-1):
        z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
        #Highlight: Target
        if learning_type == "unsupervised":
            c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1))
        elif learning_type == "semisupervised":
            c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),obs=x_class,obs_mask=class_mask)
        else:
            c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1),obs=x_class)
        #Highlight: Sequence reconstruction
        with plate("outer",dim=-2):
            logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                                    [[1,2,7],[0,2,1],[2,7,8]]])
            aa = sample("x",dist.Categorical(logits= logits),obs=x,obs_mask=obs_mask)

        return z,c,aa

def guide(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    """
    with plate("inner", dim=-1):
        z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
        if learning_type == "unsupervised":
            c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),infer={'enumerate': 'parallel'})
        elif learning_type == "semisupervised":
            c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1).mask(class_mask),infer={'enumerate': 'parallel'})
        else: #supervised
            c = 0
        #Highlight: Sequence reconstruction: When using obs_mask in the model it keeps complaining about unobserved sites. That is why added this segment here
        with plate("outer",dim=-2):
            logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                                    [[1,2,7],[0,2,1],[2,7,8]]])
            aa = sample("x",dist.Categorical(logits= logits).mask(~obs_mask),infer={'enumerate': 'parallel'}) #Still not sure if this is correct

        return z,c,aa


if __name__ == "__main__":
    learning_ops = {0:"supervised",
                     1:"unsupervised",
                    2:"semisupervised"}
    learning_type = learning_ops[0]
    x = tensor([[0,2,1],
                [0,1,1]])
    obs_mask = tensor([[1,0,0],[1,1,0]],dtype=bool) #I need a mask like this to work over the len dimension
    x_class = tensor([0,1])
    class_mask = tensor([1,0],dtype=bool) #Also this one, over the batch dimension

    guide_tr = poutine.trace(guide).get_trace(x,obs_mask,x_class,class_mask)
    model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(x,obs_mask,x_class,class_mask)
    monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
    print(monte_carlo_elbo)

    svi = SVI(model,guide,loss=TraceEnum_ELBO(),optim=ClippedAdam(dict()))
    svi.step(x,obs_mask,x_class,class_mask)

To start with, in the supervised approach, it pops a warning, which becomes an error with my actual model:

/home/.../miniconda3/lib/python3.8/site-packages/pyro/util.py:288: UserWarning: Found non-auxiliary vars in guide but not model, consider marking these infer={'is_auxiliary': True}:
{'x'}
  warnings.warn(
/home/.../miniconda3/lib/python3.8/site-packages/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
  warnings.warn(f"Found vars in model but not guide: {bad_sites}")

Feel free to split the models and the guides in 3 different ones according to learning types (I just though like this was more condensed) .Thanks in advance!

Hi @artistworking ,

In the documentation for the sample primitive it says:

  • obs_mask (bool or Tensor) – Optional boolean tensor mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.

As your warning message suggests you need to name the variable as x_unobserved in the guide for the unobserved (masked out) x.

I couldn’t find any example code in the docs or tutorials that showcases how to use obs_mask. Please feel free to create a feature request for this :slight_smile:

2 Likes

Oh, it means to “literally name it” “x_unobserved”? I did not think of that

1 Like

I have updated the code, now it works for the supervised mode (in the toy example, not in the model I have). But both for the semisupervised and unsupervised versions I get an error that I am not sure what it means. My intuition tells me that I have to do sequential inference of the discrete variables in the guide, because the parallel one might not be implemented?

import torch
from torch import tensor
from pyro import sample,plate
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI,TraceEnum_ELBO
from pyro.optim import ClippedAdam
def model(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    with plate("inner", dim=-1):
        z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
        #Highlight: Class inference
        if learning_type == "unsupervised":
            c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1))
        elif learning_type == "semisupervised":
            c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),obs=x_class,obs_mask=class_mask)
        else:
            c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1),obs=x_class)
        #Highlight: Sequence reconstruction
        with plate("outer",dim=-2):
            logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                                    [[1,2,7],[0,2,1],[2,7,8]]])
            aa = sample("x",dist.Categorical(logits= logits),obs=x,obs_mask=obs_mask)

        return z,c,aa

def guide(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    """
    with plate("inner", dim=-1):
        z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
        if learning_type == "unsupervised":
            c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),infer={'enumerate': 'parallel'})
        elif learning_type == "semisupervised":
            c = sample("c_unobserved",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1).mask(~class_mask),infer={'enumerate': 'parallel'})
        else: #supervised
            c = 0
        # #Highlight: Sequence reconstruction
        with plate("outer",dim=-2):
            logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                                    [[1,2,7],[0,2,1],[2,7,8]]])
            aa = sample("x_unobserved",dist.Categorical(logits= logits).mask(~obs_mask),infer={'enumerate': 'parallel'}) #Still not sure if this is correct

        return z,c,aa


if __name__ == "__main__":
    learning_ops = {0:"supervised",
                     1:"unsupervised",
                    2:"semisupervised"}
    learning_type = learning_ops[1]
    x = tensor([[0,2,1],
                [0,1,1]])
    obs_mask = tensor([[1,0,0],[1,1,0]],dtype=bool) #I need a mask like this
    x_class = tensor([0,1])
    class_mask = tensor([1,0],dtype=bool)

    guide_tr = poutine.trace(guide).get_trace(x,obs_mask,x_class,class_mask)
    model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(x,obs_mask,x_class,class_mask)
    monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
    print(monte_carlo_elbo)

    svi = SVI(model,guide,loss=TraceEnum_ELBO(),optim=ClippedAdam(dict()))
    svi.step(x,obs_mask,x_class,class_mask)

File “…”, line 43, in guide
c = sample(“c”, dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),infer={‘enumerate’: ‘parallel’})
File “/home/…/miniconda3/lib/python3.8/site-packages/pyro/primitives.py”, line 163, in sample
apply_stack(msg)
File “/home/…/miniconda3/lib/python3.8/site-packages/pyro/poutine/runtime.py”, line 213, in apply_stack
frame._process_message(msg)
File “/home/…/miniconda3/lib/python3.8/site-packages/pyro/poutine/messenger.py”, line 162, in _process_message
return method(msg)
File “/home/…/miniconda3/lib/python3.8/contextlib.py”, line 75, in inner
return func(*args, **kwds)
File “/home/…/miniconda3/lib/python3.8/site-packages/pyro/poutine/enum_messenger.py”, line 175, in _pyro_sample
value = enumerate_site(msg)
File “/home/lys/miniconda3/lib/python3.8/site-packages/pyro/poutine/enum_messenger.py”, line 109, in enumerate_site
value = dist.enumerate_support(expand=msg[“infer”].get(“expand”, False))
File “/home/…/miniconda3/lib/python3.8/site-packages/torch/distributions/independent.py”, line 108, in enumerate_support
raise NotImplementedError(“Enumeration over cartesian product is not implemented”)
NotImplementedError: Enumeration over cartesian product is not implemented
Trace Shapes:
Param Sites:
Sample Sites:
z dist 2 | 5
value 2 | 5
Trace Shapes:
Param Sites:
Sample Sites:
z dist 2 | 5
value 2 | 5

If you use .to_event then the distribution cannot be enumerated. Here is the issue I opened about it a while ago.

Yes, you have to use for loop so that you don’t have to_event for the Categorical distribution.

@ordabayev Ok, so something like this:

import torch
from torch import tensor
from pyro import sample,plate
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI,TraceEnum_ELBO
from pyro.optim import ClippedAdam
def model(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    :return:
    """
    with plate("inner", dim=-1):
        z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
        #Highlight: Class inference
        if learning_type == "unsupervised":
            #c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1))
            class_logits = torch.Tensor([[3, 5], [10, 8]])
            for t, y in enumerate(x_class):
                c = sample(f"c_{t}", dist.Categorical(class_logits[t]))
        elif learning_type == "semisupervised":
            #c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),obs=x_class,obs_mask=class_mask)
            class_logits = torch.Tensor([[3, 5], [10, 8]])
            for t, y in enumerate(x_class):
                c = sample(f"c_{t}", dist.Categorical(class_logits[t]),obs=x_class[t],obs_mask=class_mask[t])
        else:
            c = sample("c",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1),obs=x_class)
        #Highlight: Sequence reconstruction
        with plate("outer",dim=-2):
            logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                                    [[1,2,7],[0,2,1],[2,7,8]]])
            aa = sample("x",dist.Categorical(logits= logits),obs=x,obs_mask=obs_mask)

        return z,c,aa

def guide(x,obs_mask,x_class,class_mask):
    """
    :param x: Data [N,L,feat_dim]
    :param obs_mask: Data sites to mask [N,L]
    :param x_class: Target values [N,]
    :param class_mask: Target values mask [N,]
    """
    with plate("inner", dim=-1):
        z = sample("z",dist.Normal(torch.zeros((2,5)),torch.ones((2,5))).to_event(1))
        if learning_type == "unsupervised":
            class_logits = torch.Tensor([[3, 5], [10, 8]])
            #c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),infer={'enumerate': 'parallel'})
            for t, y in enumerate(x_class):
                c = sample(f"c_{t}_unobserved", dist.Categorical(class_logits[t]),infer={"enumerate": "parallel"})
        elif learning_type == "semisupervised":
            #c = sample("c_unobserved",dist.Categorical(logits= torch.Tensor([[3,5],[10,8]])).to_event(1),infer={'enumerate': 'parallel'})
            class_logits = torch.Tensor([[3, 5], [10, 8]])
            #c = sample("c", dist.Categorical(logits=torch.Tensor([[3, 5], [10, 8]])).to_event(1),infer={'enumerate': 'parallel'})
            for t, y in enumerate(x_class):
                c = sample(f"c_{t}_unobserved", dist.Categorical(class_logits[t]),infer={"enumerate": "parallel"})
        else: #supervised
            c = None
        # #Highlight: Sequence reconstruction
        with plate("outer",dim=-2):
            logits =  torch.Tensor([[[10,2,3],[8,2,1],[3,6,1]],
                                    [[1,2,7],[0,2,1],[2,7,8]]])
            aa = sample("x_unobserved",dist.Categorical(logits= logits).mask(~obs_mask),infer={'enumerate': 'parallel'})


        return z,c,aa


if __name__ == "__main__":
    learning_ops = {0:"supervised",
                     1:"unsupervised",
                    2:"semisupervised"}
    learning_type = learning_ops[1]
    x = tensor([[0,2,1],
                [0,1,1]])
    obs_mask = tensor([[1,0,0],[1,1,0]],dtype=bool) #I need a mask like this
    x_class = tensor([0,1])
    class_mask = tensor([1,0],dtype=bool)

    guide_tr = poutine.trace(guide).get_trace(x,obs_mask,x_class,class_mask)
    model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(x,obs_mask,x_class,class_mask)
    monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
    print(monte_carlo_elbo)

    svi = SVI(model,guide,loss=TraceEnum_ELBO(),optim=ClippedAdam(dict()))
    svi.step(x,obs_mask,x_class,class_mask)

The semisupervised does not throw warnings or errors (in the toy model), but the unsupervised does:

/home/.../miniconda3/lib/python3.8/site-packages/pyro/util.py:288: UserWarning: Found non-auxiliary vars in guide but not model, consider marking these infer={'is_auxiliary': True}:
{'c_1_unobserved', 'c_0_unobserved'}
  warnings.warn(
/home/.../miniconda3/lib/python3.8/site-packages/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'c_0', 'c_1'}
  warnings.warn(f"Found vars in model but not guide: {bad_sites}")

And this time I named them “_unobserved”

@ordabayev I have opened an issue to generate a tutorial for masking because it is indeed confusing but also a needed thing (Request for more masking tutorials · Issue #3187 · pyro-ppl/pyro · GitHub)

1 Like