Pyro performs dramatically slower than PyMC3 with Normalizing Flows on stochastic-volatility model inference

Thanks, Fritz. Following your suggestions, I redefined the model:

# define model
def model(returns):
    """Stochastic-volatility model with an AR(1) log-volatility path.

    The sample site ``"h"`` holds i.i.d. Normal innovations. The
    log-volatility path is obtained by filtering them with the AR(1)
    kernel ``phi ** k``. This is a causal convolution, truncated to ``T``
    steps. The returns are then observed as ``Normal(0, exp(h / 2))``.

    Args:
        returns: observed returns, passed as ``obs`` to site ``"y"``. Its
            ``len`` sets the plate size ``T``. Presumably a 1-D tensor
            (TODO confirm against caller).
    """
    # Map Beta on (0, 1) to (-1, 1) for the AR coefficient.
    phi = 2 * pyro.sample("phi", dist.Beta(20, 1.5)) - 1
    sigma2 = pyro.sample("sigma2", dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))

    T = len(returns)
    # Keep these as ordinary tensor expressions. Wrapping them in
    # torch.tensor(...) would copy and *detach* them from the autograd graph,
    # so gradients w.r.t. phi, sigma2 and mu would be silently lost.
    loc = mu * (1 - phi)
    scale = sigma2.sqrt()  # Normal takes a standard deviation, not a variance.

    with pyro.plate("data", T):
        # The plate broadcasts the scalar loc/scale to T i.i.d. innovations.
        eps = pyro.sample("h", dist.Normal(loc, scale))
        # Build the kernel on phi's dtype/device to avoid CPU/GPU mismatches.
        kernel = phi ** torch.arange(T, dtype=phi.dtype, device=phi.device)
        # A full convolution has length 2T - 1; keep the first T (causal) steps.
        # Slice the last dim so any leading batch dims are left intact.
        h = pyro.ops.tensor_utils.convolve(eps, kernel)[..., :T]
        pyro.sample("y", dist.Normal(0., (h / 2.).exp()), obs=returns)

However:

  1. It doesn’t seem to improve computation time much (in short toy runs it was actually slower);
  2. Now, at the very end, at the neutra.transform_sample() stage, I get the following error (note that torch.Size([5]) is len(returns), while torch.Size([10]) is the number of samples):
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-49-6ca0cf6bcf2d> in <module>
    129     args = parser.parse_args(args=[])
    130 
--> 131     outDict = main(args)

<ipython-input-49-6ca0cf6bcf2d> in main(args)
     98 
     99         print(zs.shape)
--> 100         samples = neutra.transform_sample(zs)
    101 
    102         outDict['nf_neutra_mcmc'] = mcmc

~/anaconda3/lib/python3.7/site-packages/pyro/infer/reparam/neutra.py in transform_sample(self, latent)
    102         x_unconstrained = self.transform(latent)
    103         transformed_samples = {}
--> 104         for site, value in self.guide._unpack_latent(x_unconstrained):
    105             transform = biject_to(site["fn"].support)
    106             x_constrained = transform(value)

~/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py in _unpack_latent(self, latent)
    609             event_dim = site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape)
    610             unconstrained_shape = broadcast_shape(unconstrained_shape,
--> 611                                                   batch_shape + (1,) * event_dim)
    612             unconstrained_value = latent[..., pos:pos + size].view(unconstrained_shape)
    613             yield site, unconstrained_value

~/anaconda3/lib/python3.7/site-packages/pyro/distributions/util.py in broadcast_shape(*shapes, **kwargs)
    140             elif reversed_shape[i] != size and (size != 1 or strict):
    141                 raise ValueError('shape mismatch: objects cannot be broadcast to a single shape: {}'.format(
--> 142                     ' vs '.join(map(str, shapes))))
    143     return tuple(reversed(reversed_shape))
    144 

ValueError: shape mismatch: objects cannot be broadcast to a single shape: torch.Size([5]) vs torch.Size([10])

Could you point out whether I am doing something incorrectly, or suggest any workarounds?

Thanks,
Arturs