Using a guide for sample sites downstream of an enumerated site

Hey guys,

As far as I can tell, Pyro does not allow a sample site used in a guide to be downstream of a discrete sample site when that discrete site uses parallel enumeration. For example, take this simple structure:

for t, y in pyro.markov(enumerate(data)):
    s = pyro.sample(f"s_{t}", dist.Categorical(transition_matrix[s]), infer={"enumerate": "parallel"})

    mask = s == 1

    dist_z = dist.Normal(
        mask * z * Vindex(x_ar_coeff)[1]
        + (~mask) * z * Vindex(x_ar_coeff)[0],
        x_ar_noise,
    ).to_event(1)

    z = pyro.sample(f"z_{t}", dist_z)

    # MORE CODE BELOW, BUT NOT RELEVANT

Even if I use an autoguide, Pyro throws an error saying that the guide and model sample sites do not match at site z_{t}, for every t. The reason seems to be that during enumeration dist_z is expanded to the shape of the enumerated variable s, while Pyro requires the guide and model shapes at a site to match. Is there a way to make the sample shapes agree at the site pyro.sample(f"z_{t}", dist_z)? Broadcasting the guide's sample to the model's shape would fix the problem, but Pyro raises an error here instead of broadcasting.
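One shape detail that may be involved: with `max_plate_nesting=1`, parallel enumeration places the values of `s` along a dimension to the left of the plate dimension, so `mask` (and therefore the `Normal` location) picks up an extra leading dimension of size 2. `.to_event(1)` then reinterprets the *rightmost* batch dimension as an event dimension, which moves that enumeration dimension into the slot Pyro reserves for plates. The sketch below reproduces only the shape arithmetic with plain `torch.distributions`. Pyro's `.to_event(1)` wraps a distribution in `Independent(..., 1)`. The `(2, 1)` shape used for the enumerated `s` is an assumption based on `max_plate_nesting=1`, and `z_prev` is a hypothetical stand-in for the previous state.

```python
import torch
import torch.distributions as tdist

# Assumed shape of the enumerated s: enumeration dim at -2, plate slot at -1.
s = torch.arange(2).reshape(2, 1)
z_prev = torch.zeros(1)  # hypothetical previous state with one trailing event dim
mask = s == 1

# Same construction as dist_z in the snippet above.
loc = mask * z_prev * 0.99 + (~mask) * z_prev * 0.5
folded = tdist.Independent(tdist.Normal(loc, 1.0), 1)

# Variant: give the mask its own trailing event dim before broadcasting,
# so the enumeration dim stays in batch position -2.
loc_kept = mask.unsqueeze(-1) * z_prev * 0.99 + (~mask).unsqueeze(-1) * z_prev * 0.5
kept = tdist.Independent(tdist.Normal(loc_kept, 1.0), 1)

print(loc.shape, folded.batch_shape, folded.event_shape)
print(loc_kept.shape, kept.batch_shape, kept.event_shape)
```

In the first case, the size-2 enumeration dimension ends up at batch position -1, which is where a `pyro.plate` would go under `max_plate_nesting=1`. In the second case, it stays at -2. Whether this alone explains the reported error has not been verified here.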

Hi @mike_schoehals, could you please provide a minimal reproducible example (with all imports)? I can take a look at it.

@ordabayev

Thanks for taking a look at this. Here is a full end-to-end example:

import numpy as np
import torch
from torch.distributions import constraints
import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
from pyro.infer.autoguide import AutoNormal, AutoDiagonalNormal


N = 10
data = torch.randn(N)


# NOTE THAT THIS IS ESSENTIALLY A REGIME-SWITCHING AR(1) STATE-SPACE MODEL.

def model(data, is_enum=False):
    trans_prob = pyro.param("trans_prob", torch.randn(2).exp(), constraint=constraints.simplex)
    transition_matrix = torch.stack([trans_prob, torch.flip(trans_prob, (0,))], dim=0)

    z = torch.tensor([0])
    s = torch.tensor([0])

    for t, y in pyro.markov(enumerate(data)):
        # for t, y in enumerate(data):
        s = pyro.sample(f"s_{t}", dist.Categorical(transition_matrix[s]), infer={"enumerate": "parallel"})

        mask = s == 1

        dist_z = dist.Normal(
            mask * z * 0.99
            + (~mask) * z * 0.5,
            1,
        ).to_event(1)

        z = pyro.sample(f"z_{t}", dist_z)

        obs = pyro.sample(f"y_{t}", dist.Normal(z, 1).to_event(1), obs=y)

        # This is required in this regime-switching state-space model:
        # z is constant along the enumeration dimension, and only one value
        # should carry over to the next iteration of the loop.
        if is_enum:
            z = z[0, ...].squeeze().unsqueeze(-1)


optim = pyro.optim.Adam({"lr": 1e-1})
guide = AutoDiagonalNormal(poutine.block(model, hide=[f"s_{t}" for t in range(data.shape[0])]))

svi = SVI(model, guide, optim, loss=TraceEnum_ELBO(retain_graph=True, max_plate_nesting=1))  # , num_particles=10, vectorize_particles=True))
pyro.set_rng_seed(0)
for i in range(2000):
    loss = svi.step(data, is_enum=True)
    if not i % 10:
        print("loss: ", loss)

@ordabayev No worries if you haven't looked at this, but if you have, I was curious whether you'd found anything. This case is useful for regime-switching time series, and being able to use parallel enumeration gives very low-variance gradient estimates (as you are aware).
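To make the variance point concrete, here is the regime-switching AR(1) model from the full example written out. It is read off the code, and $P$, $a_k$ and the initial values are just names for quantities that appear there. Alongside it is the objective `TraceEnum_ELBO` optimizes when the $s_t$ are enumerated in the model and only the $z_t$ are covered by the guide. Here $(p_0, p_1)$ is `trans_prob`, with $s_{-1} = 0$ and $z_{-1} = 0$:

$$
\begin{aligned}
s_t \mid s_{t-1} &\sim \operatorname{Categorical}(P_{s_{t-1},\cdot}), & P &= \begin{pmatrix} p_0 & p_1 \\ p_1 & p_0 \end{pmatrix},\\
z_t \mid z_{t-1}, s_t &\sim \mathcal{N}(a_{s_t} z_{t-1},\, 1), & a_0 &= 0.5,\quad a_1 = 0.99,\\
y_t \mid z_t &\sim \mathcal{N}(z_t,\, 1), & t &= 0, \dots, N-1,
\end{aligned}
$$

$$
\mathcal{L}(q) = \mathbb{E}_{q(z_{0:N-1})}\Big[\log \sum_{s_{0:N-1}} p(s_{0:N-1}, z_{0:N-1}, y_{0:N-1})\Big] - \mathbb{E}_{q(z_{0:N-1})}\big[\log q(z_{0:N-1})\big].
$$

The sum over the $2^N$ regime sequences is contracted exactly by variable elimination, which `pyro.markov` keeps linear in $N$. Only the expectation over $z$ is estimated by sampling, so no score-function term for the discrete $s_t$ is needed. That is where the variance reduction comes from.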

Hi @mike_schoehals, sorry for not replying. I had an initial look and couldn't figure it out, and I haven't had time since then to take a deeper look. I'll post here if I make any progress.

@ordabayev No worries, thanks for the reply!