Hi all,
I am using Pyro version 1.4.0 and have a toy HMM:
class HMM(torch.nn.Module):
    """Toy state-space model with a Gaussian random-walk latent chain.

    Latents: ``z_0 ~ N(0, I)`` and ``z_t ~ N(z_{t-1}, I)``, each of size
    ``dim_z``. Observations: ``x_t ~ N(model_net.z_to_x(z)[..., t, :], I)``.
    """

    def __init__(self, dim_x, dim_z, T):
        super().__init__()
        # Latent state size and number of time steps.
        self.dim_z = dim_z
        self.T = T
        # Emission network (defined elsewhere). Presumably maps
        # (..., T, dim_z) -> (..., T, dim_x); TODO confirm.
        self.model_net = ModelNet(dim_x, dim_z)

    def prior(self, b_s):
        """Sample the latent random walk ``z_0 .. z_{T-1}``.

        ``b_s`` is accepted but not used here. Any batch dimension comes from
        an enclosing ``pyro.plate``.

        Returns a tensor whose trailing two dims are ``(T, dim_z)``.
        """
        states = []
        loc = torch.zeros(self.dim_z)
        for step in range(self.T):
            # Each state is centred on the previous draw (zeros for step 0).
            loc = pyro.sample('z_{}'.format(step), dist.Normal(loc, 1.).to_event(1))
            states.append(loc)
        # Stack on a new trailing time axis, then swap it in front of the
        # feature axis so that the layout ends in (T, dim_z).
        return torch.stack(states, dim=-1).transpose(-1, -2)

    def model(self, x):
        """Pyro model: latent random walk plus unit-scale Gaussian emissions.

        ``x`` is indexed per time step as ``x[t]``, and ``x.shape[1]`` is the
        batch size, i.e. a ``(T, batch, features)`` layout.
        """
        pyro.module('hmm', self)
        batch_size = x.shape[1]
        with pyro.plate('batch_plate', batch_size, dim=-1):
            z = self.prior(batch_size)
            # Debug output. The labels assume z carries a leading
            # vectorized-particle dim ahead of the batch dim.
            probes = (
                ('example 0 particle 0:', (0, 0)),
                ('example 0 particle 1:', (1, 0)),
                ('example 1 particle 0:', (0, 1)),
                ('example 1 particle 1:', (1, 1)),
            )
            for label, index in probes:
                print(label, z[index])
            mu_x = self.model_net.z_to_x(z)
            for t in range(self.T):
                pyro.sample('x_{}'.format(t), dist.Normal(mu_x[..., t, :], 1.).to_event(1), obs=x[t])
        return z
In the main method, I set up an instance of this HMM and try to find the MAP solution for a batch of examples:
# Problem sizes.
dim_x = 10
T = 20
minibatch_size = 5

# Synthetic observations laid out as (time, batch, feature).
x = torch.randn(T, minibatch_size, dim_x)

dim_z = 10
hmm = HMM(dim_x, dim_z, T)

# AutoDelta is a MAP guide: each latent site is a point mass at a learned
# location, initialised here by init_to_sample. With vectorize_particles=True,
# every particle replays that same point estimate into the model. The z values
# printed inside the model therefore match across the particle dim, and extra
# particles add compute but no variance.
num_particles = 3
elbo = Trace_ELBO(
    num_particles=num_particles,
    vectorize_particles=True,
    max_plate_nesting=1,
)
guide = AutoDelta(hmm.model, init_loc_fn=init_to_sample)
optimizer = Adam({})
svi = SVI(hmm.model, guide, optimizer, elbo)
loss = svi.step(x)
However, the printout from inside the model shows that the z samples from `prior()` differ across the batch dimension but not across the particle dimension.
I have not set the torch or the Pyro random seed. If I manually wrap the model in a vectorized particle plate and sample from the prior, the samples differ across both the particle and batch dimensions, as expected:
max_plate_nesting = 2
with pyro.plate("num_particles_vectorized", num_particles, dim=-max_plate_nesting):
z = poutine.uncondition(hmm.model)(x)
What am I missing here? Thank you in advance!