How to use TraceTMC_ELBO correctly?

Hello,
I’m working on a dynamical variational autoencoder (a VRNN model) in Pyro, closely following the deep Markov model (DMM) example. I’m wondering how to use TraceTMC_ELBO correctly. In the DMM, the guide is simply wrapped with config_enumerate, but doing the same in my code raises an error. Did I miss something?

def model(self, x, annealing_factor=1.0):
    b, l, _ = x.shape
    pyro.module("vrnn", self)
    h = x.new_zeros([b, self.hidden_dim])

    feature_x = self.feature_extra_x(x)
    with pyro.plate("data", b):
        for t in range(1, l + 1):
            # prior p(z_t | h_{t-1})
            z_loc, z_scale = self.theta_norm_1(h)
            with poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(f"z_{t}", dist.Normal(z_loc, z_scale).to_event(1))

            # likelihood p(x_t | z_t, h_{t-1})
            z_t = self.feature_extra_z(z_t)
            z_with_h = torch.cat([z_t, h], dim=1)
            x_loc, x_scale = self.theta_norm_2(z_with_h)
            pyro.sample(f"obs_{t}", dist.Normal(x_loc, x_scale).to_event(1), obs=x[:, t - 1, :])

            # recurrence: h_t depends on x_t, the sampled z_t, and h_{t-1}
            x_with_z = torch.cat([feature_x[:, t - 1, :], z_t], dim=1)
            h = self.rnn(x_with_z, h)

def guide(self, x, annealing_factor=1.0):
    b, l, _ = x.shape
    h = x.new_zeros([b, self.hidden_dim])

    feature_x = self.feature_extra_x(x)
    with pyro.plate("data", b):
        for t in range(1, l + 1):
            # posterior q(z_t | x_t, h_{t-1})
            xt = feature_x[:, t - 1, :]
            h_with_x = torch.cat([h, xt], dim=1)
            z_loc, z_scale = self.phi_norm(h_with_x)
            with poutine.scale(scale=annealing_factor):
                z_t = pyro.sample(f"z_{t}", dist.Normal(z_loc, z_scale).to_event(1))

            # recurrence feeds the sampled z_t back into the RNN state
            z_t = self.feature_extra_z(z_t)
            x_with_z = torch.cat([xt, z_t], dim=1)
            h = self.rnn(x_with_z, h)

@eb8680_2

Hi @christophe, it’s hard to say what is causing your error without a stack trace or runnable code reproducing the problem. However, since TraceTMC_ELBO is not well documented, I will give some background for you and any future users with similar questions.

I will also explain why, unlike the DMM, I don’t think your specific model and guide are compatible with TraceTMC_ELBO even if you were using it correctly; you will need to use pyro.infer.Trace_ELBO instead (or pyro.infer.RenyiELBO if you want an IWAE estimator). If you just want to skip to TraceTMC_ELBO usage examples, take a look at these unit tests.

First, TraceTMC_ELBO is very similar, both conceptually and in implementation, to exact enumeration over discrete model variables via pyro.infer.TraceEnum_ELBO; in fact, you could recover TraceTMC_ELBO from TraceEnum_ELBO by introducing auxiliary discrete variables in the model that index sample subsets. All of the same background and caveats discussed in the Tensor Shapes and Enumeration tutorials therefore apply.

To use TraceTMC_ELBO, you need to indicate that each latent variable should be sampled with TMC, either via the infer argument to pyro.sample:

pyro.sample(f"z_{t}", dist.Normal(...), infer={
    "enumerate": "parallel",
    "num_samples": num_tmc_samples,
    "expand": False,
    "tmc": "diagonal",
})

or via config_enumerate applied to the guide:

@config_enumerate(
    default="parallel",
    num_samples=num_tmc_samples,
    expand=False,
    tmc="diagonal",
)
def guide(...):
    ...
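
For completeness, here is a minimal sketch of how these pieces plug into SVI, assuming the model and guide from your post (num_tmc_samples and the learning rate are placeholder values):

import pyro
from pyro.infer import SVI, TraceTMC_ELBO, config_enumerate
from pyro.optim import Adam

num_tmc_samples = 10  # placeholder; number of TMC samples per site

# configure every guide sample site for TMC, as shown above
tmc_guide = config_enumerate(
    guide, default="parallel", num_samples=num_tmc_samples,
    expand=False, tmc="diagonal",
)

# max_plate_nesting must cover the deepest pyro.plate (here: "data")
elbo = TraceTMC_ELBO(max_plate_nesting=1)
svi = SVI(model, tmc_guide, Adam({"lr": 1e-3}), elbo)
loss = svi.step(x)  # same arguments as your model/guide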

You would also need to make sure your model supports broadcasting on the left, as described in this section of the Tensor Shapes tutorial (e.g. the dimension arguments to your torch.cat calls should be negative), and wrap each for-loop with pyro.markov, as described in this section of the enumeration tutorial:

for t in pyro.markov(range(1, l + 1)):
    ...
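
Concretely, a hedged sketch of what those two changes would look like in your guide’s inner loop (necessary but, as explained below, not sufficient for this particular guide):

with pyro.plate("data", b):
    for t in pyro.markov(range(1, l + 1)):
        xt = feature_x[:, t - 1, :]
        # dim=-1 instead of dim=1, so the concat still works when
        # enumeration adds batch dimensions on the left
        h_with_x = torch.cat([h, xt], dim=-1)
        z_loc, z_scale = self.phi_norm(h_with_x)
        with poutine.scale(scale=annealing_factor):
            z_t = pyro.sample(f"z_{t}", dist.Normal(z_loc, z_scale).to_event(1))
        ...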

However, I don’t think your model and guide are compatible with TraceTMC_ELBO anyway, because each z_t depends directly on all previous z variables through your RNN state h, which means the number of distinct sampled values of h grows exponentially in your sequence length l, as discussed in the enumeration tutorial. Compare that to the DMM example, where each z_t is conditionally independent of the others given the previous z_t and all observations. You could cook up a different, biased TMC estimator that simply ignores the paths through the RNN, but that’s beyond the scope of this post.
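
For contrast, a hedged sketch of a DMM-style guide structure that would be compatible: the RNN consumes only the observations, so h is a deterministic function of x and each z_t depends only on z_{t-1}. Here rnn_over_x, combiner, and z_0 are hypothetical modules, not ones from your code, and combiner would still need to broadcast on the left:

# run the RNN over the observations alone; h no longer depends on sampled z's
h = self.rnn_over_x(feature_x)   # hypothetical module, shape [b, l, hidden_dim]
z_prev = self.z_0.expand(b, -1)  # hypothetical learned initial state
with pyro.plate("data", b):
    for t in pyro.markov(range(1, l + 1)):
        z_loc, z_scale = self.combiner(z_prev, h[:, t - 1, :])  # hypothetical
        z_prev = pyro.sample(f"z_{t}", dist.Normal(z_loc, z_scale).to_event(1))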

Finally, a couple of caveats. First, I have not had much success with TraceTMC_ELBO in the few experiments I have done with it, although that is probably because I am less concerned with model learning, and because there are implementation choices and missing statistical optimizations in our version that may be increasing gradient variance.

I should also say that the implementation in pyro.infer.TraceTMC_ELBO is somewhat inefficient because it draws samples it never uses, which can be especially problematic when you have vector-valued latent variables (as in the DMM) and limited GPU memory. The alternative implementation, pyro.contrib.funsor.infer.TraceTMC_ELBO, does not have this problem (see this example for how to use it), but it has high constant-factor overhead for other reasons unrelated to the TMC estimator.
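
If you want to try the funsor-based version anyway, here is a minimal sketch following the pyroapi backend-switching pattern used in that example (the exact module layout may differ between Pyro versions):

import pyro.contrib.funsor  # registers the "contrib.funsor" pyroapi backend
from pyroapi import infer, optim, pyro_backend

with pyro_backend("contrib.funsor"):
    elbo = infer.TraceTMC_ELBO(max_plate_nesting=1)
    svi = infer.SVI(model, tmc_guide, optim.Adam({"lr": 1e-3}), elbo)
    loss = svi.step(x)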


Thanks for your kind answer, I’ve switched back to Trace_ELBO.
