Hey guys,
As far as I can tell, Pyro does not allow a latent sample site that the guide samples to be downstream of a discrete sample site when that discrete site uses parallel enumeration. For example, take this simple structure:
```python
for t, y in pyro.markov(enumerate(data)):
    s = pyro.sample(f"s_{t}", dist.Categorical(transition_matrix[s]), infer={"enumerate": "parallel"})
    mask = s == 1
    dist_z = dist.Normal(
        mask * z * Vindex(x_ar_coeff)[1]
        + (~mask) * z * Vindex(x_ar_coeff)[0],
        x_ar_noise,
    ).to_event(1)
    z = pyro.sample(f"z_{t}", dist_z)
    # MORE CODE BELOW, BUT NOT RELEVANT
```
Even if I use an autoguide, Pyro throws an error saying that the guide and model sample sites do not match at site `z_{t}`, for every t. The reason seems to be that during enumeration `dist_z` is expanded to the shape of the enumerated variable `s`, while Pyro requires the guide and model shapes at a site to match. Is there a way to get the sample shapes to match at the site `pyro.sample(f"z_{t}", dist_z)`? Broadcasting from the guide to the model would fix the problem, but Pyro doesn't broadcast here and throws an error instead.
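To make the shape problem concrete, here is a small sketch using plain `torch.distributions` rather than Pyro itself. It assumes Pyro's usual tensor-shape convention: with `max_plate_nesting=1`, the rightmost batch dim is reserved for plates and the first enumerated value of `s` is placed at dim -2, so `s` has shape `(2, 1)`. Because `z` has event shape `(1,)`, a location of shape `(2, 1)` leaves the enumeration dim as the rightmost *batch* dim, i.e. in the slot reserved for the plate, where it cannot broadcast against the guide's `z` of shape `(1,)`. The `unsqueeze(-1)` variant is a hypothetical workaround, not something confirmed in this thread; it only shows where the dims end up and is not a verified fix for the full model.

```python
import torch
from torch.distributions import Independent, Normal

# Assumption: with max_plate_nesting=1, Pyro puts the enumerated values of a
# binary s at dim -2, giving shape (2, 1).
s = torch.arange(2).reshape(2, 1)
z = torch.zeros(1)  # previous latent state, event shape (1,)


def regime_dist(mask, z):
    """Normal over z with a regime-dependent AR coefficient and event_dim=1."""
    loc = mask * z * 0.99 + (~mask) * z * 0.5
    return Independent(Normal(loc, 1.0), 1)


# As in the post: loc has shape (2, 1); after to_event(1) the enumeration dim
# becomes the rightmost batch dim, which is the dim reserved for the plate.
orig = regime_dist(s == 1, z)
print(orig.batch_shape, orig.event_shape)  # torch.Size([2]) torch.Size([1])

# Hypothetical workaround: a trailing singleton on the mask pushes the
# enumeration dim to the left of the plate dim.
fixed = regime_dist((s == 1).unsqueeze(-1), z)
print(fixed.batch_shape, fixed.event_shape)  # torch.Size([2, 1]) torch.Size([1])
```

Scoring a guide value of shape `(1,)` under `orig` gives a log-probability of shape `(2,)`, while under `fixed` it gives shape `(2, 1)`, which broadcasts cleanly against a single plate dim.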
Hi @mike_schoehals, can you provide minimal reproducible code (with all imports), please? I can have a look at it.
@ordabayev
Thanks for taking a look at this. Here is a full end-to-end example:
```python
import numpy as np
import torch
from torch.distributions import constraints
import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO
from pyro.infer.autoguide import AutoNormal, AutoDiagonalNormal

N = 10
data = torch.randn(N)


# NOTE THAT THIS IS ESSENTIALLY A REGIME-SWITCHING AR(1) STATE-SPACE MODEL.
def model(data, is_enum=False):
    trans_prob = pyro.param("trans_prob", torch.randn(2).exp(), constraint=constraints.simplex)
    transition_matrix = torch.stack([trans_prob, torch.flip(trans_prob, (0,))], dim=0)
    z = torch.tensor([0])
    s = torch.tensor([0])
    for t, y in pyro.markov(enumerate(data)):
        # for t, y in enumerate(data):
        s = pyro.sample(f"s_{t}", dist.Categorical(transition_matrix[s]), infer={"enumerate": "parallel"})
        mask = s == 1
        dist_z = dist.Normal(
            mask * z * 0.99
            + (~mask) * z * 0.5,
            1,
        ).to_event(1)
        z = pyro.sample(f"z_{t}", dist_z)
        obs = pyro.sample(f"y_{t}", dist.Normal(z, 1).to_event(1), obs=y)
        # This is required in this regime-switching state-space model:
        # z is constant along the enum dimension, and we only want one value
        # to continue to the next iteration of the loop.
        if is_enum:
            z = z[0, ...].squeeze().unsqueeze(-1)


optim = pyro.optim.Adam({'lr': 1e-1})
guide = AutoDiagonalNormal(poutine.block(model, hide=[f's_{t}' for t in range(data.shape[0])]))
svi = SVI(model, guide, optim, loss=TraceEnum_ELBO(retain_graph=True, max_plate_nesting=1))  # , num_particles=10, vectorize_particles=True))
pyro.set_rng_seed(0)
for i in range(2000):
    loss = svi.step(data, True)
    if not i % 10:
        print('loss: ', loss)
```
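For context on why parallel enumeration is attractive in this model: for a single step, enumerating `s_t` replaces a sampled discrete value with an exact sum over both regimes. The sketch below computes that sum by hand with `torch.distributions` for one transition. The values of `trans_prob`, `z_prev` and `z_t` are made up for illustration (hypothetical, not from the post), and it ignores the chaining across time steps that `pyro.markov` handles, so it is only a rough picture of what `TraceEnum_ELBO` does at this site.

```python
import math

import torch
from torch.distributions import Categorical, Independent, Normal

# Hypothetical values for a single time step (not taken from the post).
trans_prob = torch.tensor([0.7, 0.3])
transition_matrix = torch.stack([trans_prob, torch.flip(trans_prob, (0,))], dim=0)
s_prev = 0                    # regime at t-1
z_prev = torch.tensor([1.0])  # latent state at t-1, event shape (1,)
z_t = torch.tensor([0.8])     # value the guide proposed for z_t

# Evaluate both regimes at once along a leading dim of size 2.
s_values = torch.arange(2)                        # shape (2,)
ar_coeff = torch.tensor([0.5, 0.99])              # regime 0 -> 0.5, regime 1 -> 0.99
loc = ar_coeff[s_values].unsqueeze(-1) * z_prev   # shape (2, 1)

log_p_s = Categorical(probs=transition_matrix[s_prev]).log_prob(s_values)  # (2,)
log_p_z = Independent(Normal(loc, 1.0), 1).log_prob(z_t)                   # (2,)

# Exact marginal over s_t: no sampled discrete value, so no extra variance from it.
log_marginal = torch.logsumexp(log_p_s + log_p_z, dim=0)


# The same quantity written out by hand for comparison.
def normal_pdf(x, mu, sigma=1.0):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


manual = math.log(0.7 * normal_pdf(0.8, 0.5 * 1.0) + 0.3 * normal_pdf(0.8, 0.99 * 1.0))
print(log_marginal.item(), manual)
```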
@ordabayev No worries if you haven't looked at this, but if you have, I'm curious whether you found anything. This case is useful for regime-switching time series, and parallel enumeration gives very low-variance estimates (as you guys are aware).
Hi @mike_schoehals, sorry for not replying. I had an initial look and couldn't figure it out, and I haven't had time for a deeper look since then. I'll post here if I make any progress.
@ordabayev No worries, thanks for the reply!