Hi,
I am using Pyro version 1.4.0. I am modelling a time series and I use SVI to also learn an inference network that outputs the distribution of the latent variables at each time step (given only the input at that specific time step).
def guide(self, x):
    """Amortized variational guide: maps observations x to posteriors over (z, w).

    Args:
        x: observed sequence; assumes shape (T, batch, dim_x) — TODO confirm
           against the caller, since only x.shape[1] is read here.

    Side effects:
        Registers this module's parameters and the sample sites
        'z_0'..'z_{T-1}' and 'w_0'..'w_{T-1}' with Pyro.
    """
    pyro.module('latent_dependency_vae', self)
    b_s = x.shape[1]
    with pyro.plate('batch_plate', b_s, dim=-1):
        # Per-time-step posterior parameters for z, computed from x alone.
        mu_z, sigma_z = self.inference_net.x_to_z(x)
        z = torch.empty(self.T, self.num_samples_elbo, b_s, self.dim_z)
        for t in range(self.T):
            z[t] = pyro.sample('z_{}'.format(t),
                               dist.Normal(mu_z[t], sigma_z[t]).to_event(1))
        # The posterior over w conditions on the sampled z trajectory.
        mu_w, sigma_w = self.inference_net.z_to_w(z)
        for t in range(self.T):
            # Only the pyro.sample side effect (site registration) is needed;
            # the return value was previously bound to an unused local.
            pyro.sample('w_{}'.format(t),
                        dist.Normal(mu_w[t], sigma_w[t]).to_event(1))
In my model, I place a simple Markovian prior on the latent variables:
def prior(self, b_s):
    """Markovian (random-walk) prior over the latent trajectories z and w.

    Args:
        b_s: minibatch size.

    Returns:
        Tuple (z, w) with shapes (T, num_samples_elbo, b_s, dim_z) and
        (T, num_samples_elbo, b_s, dim_w) respectively.
    """
    z = torch.empty((self.T, self.num_samples_elbo, b_s, self.dim_z))
    w = torch.empty((self.T, self.num_samples_elbo, b_s, self.dim_w))
    # Initial states: zero-centered Gaussians with a tight 1e-3 scale.
    z[0] = pyro.sample('z_0', dist.Normal(torch.zeros(self.dim_z), 1e-3).to_event(1))
    w[0] = pyro.sample('w_0', dist.Normal(torch.zeros(self.dim_w), 1e-3).to_event(1))
    # Each subsequent step is a narrow Gaussian around the previous one.
    for step in range(1, self.T):
        z[step] = pyro.sample('z_{}'.format(step),
                              dist.Normal(z[step - 1], 1e-3).to_event(1))
        w[step] = pyro.sample('w_{}'.format(step),
                              dist.Normal(w[step - 1], 1e-3).to_event(1))
    return z, w
And these are used to generate the emission:
def model(self, x):
    """Generative model: latents from the Markovian prior drive emissions of x.

    Args:
        x: observed sequence; assumes shape (T, batch, dim_x) — TODO confirm.

    Side effects:
        Registers module parameters, the latent sites via self.prior, and the
        observed sites 'x_0'..'x_{T-1}'.
    """
    pyro.module('latent_dependency_vae', self)
    b_s = x.shape[1]
    with pyro.plate('batch_plate', b_s, dim=-1):
        z, w = self.prior(b_s)
        # NOTE(review): despite its name, inference_net also hosts the
        # emission network latents_to_x, used generatively here.
        mu_x = self.inference_net.latents_to_x(z, w)
        for t in range(self.T):
            # Narrow-variance Gaussian emission conditioned on observed x[t].
            # The return value was previously bound to an unused local x_gen.
            pyro.sample('x_{}'.format(t),
                        dist.Normal(mu_x[t], 1e-3).to_event(1),
                        obs=x[t])
If I then try to use the Predictive class like so:
# Smoke-test configuration for drawing posterior-predictive samples.
dim_x, dim_z, dim_w = 10, 10, 10
T = 5
num_samples = 10

vae_model = LatentDependencyVAE(dim_x, dim_z, dim_w, T, num_samples)
# Uncondition the model so observed sites are sampled; draw vectorized samples.
predictive = Predictive(
    poutine.uncondition(vae_model.model),
    guide=vae_model.guide,
    num_samples=num_samples,
    parallel=True,
)

minibatch_size = 1
x = torch.randn(T, minibatch_size, dim_x)
samples = predictive(x)
where LatentDependencyVAE is the wrapper containing prior, model and guide:
class LatentDependencyVAE(torch.nn.Module):
    """Wrapper module bundling the prior, model and guide of the sequential VAE."""

    def __init__(self, dim_x, dim_z, dim_w, T, num_samples_elbo):
        super().__init__()
        # Latent dimensionalities and sequence length.
        self.dim_z = dim_z
        self.dim_w = dim_w
        self.T = T
        # Number of ELBO particles baked into the sample-tensor shapes.
        self.num_samples_elbo = num_samples_elbo
        # Amortized inference network (x -> z, z -> w, and the emission net).
        self.inference_net = InferenceNet(dim_x, dim_z, dim_w)
I run into the problem that
model_trace.nodes[site]["fn"].batch_shape
in _predictive() (line 70) incorrectly reports w as having a batch_shape of [10, 1], when it is actually [1] — the leading [10] is the number of samples of w, which belongs under the vectorized plate.
Clearly this is a flaw in how I have written the model and the guide — they explicitly bake num_samples into the sample shapes — but I am not sure how to write them otherwise. Sorry for the long read, and thank you for any help!