@eb8680_2 here is a runnable example of the code. I tried to keep it as close to the original problem as possible.
Here I am trying to work with `batch_dim = -1`, and I figured out that the problem seems to be the masking, specifically the `.unsqueeze(-1)` in this line:

```python
with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1)):
```

Do you think I should drop the `.unsqueeze(-1)`? And how should I move to a model with `batch_dim = -2`?
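To make the shape issue concrete, here is a minimal check of the two mask shapes outside of Pyro (plain tensor broadcasting, using the sizes from the example below; `t = 5` is just an arbitrary step):

```python
import torch

batch_size, seq_length, t = 30, 1000, 5

# Without the unsqueeze the mask has shape (batch_size,), which lines up
# with the "sequence" plate when it sits at dim=-1.
mask_1d = t < torch.ones(batch_size) * seq_length
print(mask_1d.shape)  # torch.Size([30])

# After .unsqueeze(-1) the mask has shape (batch_size, 1), so the batch
# dimension now sits at dim=-2 instead of dim=-1.
mask_2d = mask_1d.unsqueeze(-1)
print(mask_2d.shape)  # torch.Size([30, 1])
```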
```python
from unittest import TestCase

import torch
from torch import nn

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex

pyro.enable_validation()
pyro.set_rng_seed(0)

num_seq = 100
seq_length = 1000


class TestEmitter(nn.Module):
    def __init__(self, w_dim=1, z_dim=1):
        super().__init__()
        self.lin_w_to_hidden = nn.Linear(w_dim, z_dim)
        # Normalize over the last dim so each row is a probability vector.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, w, z):
        w = w.unsqueeze(-1)
        w_hidden = self.lin_w_to_hidden(w)
        out_sm = self.softmax(w_hidden)  # alpha is size of z
        return out_sm


@config_enumerate
def model(inp):
    num_states = 6
    batch_size = 30
    x = inp['x']
    y = inp['y']
    state_emitter = TestEmitter(w_dim=1, z_dim=6)
    obs_emitter = TestEmitter(w_dim=1, z_dim=25)
    pyro.module("state_emitter", state_emitter)
    pyro.module("obs_emitter", obs_emitter)
    with poutine.mask(mask=True):
        probs_lat = pyro.sample(
            "probs_lat",
            dist.Dirichlet(0.5 * torch.eye(num_states) + 0.5 / (num_states - 1)).to_event(1))
    z = torch.zeros(1)
    y_hat = torch.zeros(1)
    with pyro.plate("sequence", size=num_seq, subsample_size=batch_size) as batch:
        for t in pyro.markov(range(0, seq_length - 1)):
            # with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length)):
            with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1)):
                px = state_emitter(x[t, batch].float(), z)
                z = pyro.sample(f"z_{t}",
                                dist.Categorical(Vindex(probs_lat)[..., px.argmax(dim=1), :]))
                py = obs_emitter(z.float(), y_hat)
                y_hat = pyro.sample(f"y_{t}", dist.Categorical(py), obs=y[t, batch])


def generate_data():
    x = torch.randint(0, 30, size=(seq_length, num_seq))
    y = torch.randint(0, 25, size=(seq_length, num_seq))
    seq = dict(x=x, y=y)
    return seq


def print_shapes(model, guide, seq, first_available_dim=-3):
    guide_trace = poutine.trace(guide).get_trace(seq)
    model_trace = poutine.trace(
        poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(seq)
    print(model_trace.format_shapes())


class TestPyroARDBN(TestCase):
    def test_architecture(self):
        seq = generate_data()
        guide = AutoDelta(poutine.block(model, expose=["probs_lat"]))
        pyro.clear_param_store()
        elbo = TraceEnum_ELBO(max_plate_nesting=3)
        print_shapes(model, guide, seq, first_available_dim=-3)
        elbo.loss(model, guide, seq)
```
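For reference, this is roughly how I imagine the `batch_dim = -2` version of the loop would look: the plate moves to `dim=-2`, and then (if I understand the broadcasting right) the `(batch_size, 1)` mask produced by `.unsqueeze(-1)` is exactly the shape that plate expects. Just a sketch of the changed block, with the loop body elided:

```python
import torch
import pyro
from pyro import poutine

num_seq, batch_size, seq_length = 100, 30, 1000

# Same loop as in the model above, but with the sequence plate at dim=-2.
with pyro.plate("sequence", size=num_seq, subsample_size=batch_size, dim=-2) as batch:
    for t in pyro.markov(range(0, seq_length - 1)):
        # The batch dimension now lives at dim=-2, so the mask keeps the
        # trailing singleton dim: shape (batch_size, 1).
        with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1)):
            ...  # same body as in the model above
```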