Partially reparametrize by drawing multiple discrete samples at one site?

Suppose we have a model with a discrete latent variable y that depends on a continuous latent variable z, but y has too many categories to enumerate (e.g. a genetic sequence):

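import pyro

# SomeContinuousDistribution, SomeDiscreteDistribution and SomeLikelihood are placeholders.
def model(data):
    z = pyro.sample("z", SomeContinuousDistribution())
    y = pyro.sample("y", SomeDiscreteDistribution(z))
    pyro.sample("obs", SomeLikelihood(y), obs=data)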

To learn a posterior over z we could use Trace_ELBO, which uses REINFORCE (aka the score function estimator), or, if y has only a few categories, we could use TraceEnum_ELBO, which marginalizes out y (aka Rao-Blackwellizes, aka enumerates) and reparametrizes the gradients of z.

But intuitively, an inference algorithm should be able to draw multiple samples of y for each sample of z and compute hybrid gradients: REINFORCE applied to the logmeanexp aggregate likelihood over the y samples, combined with backpropagation through the reparametrized samples of z based on the relative likelihoods of the y samples.
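To make the idea concrete, here is a rough sketch in plain Python (no Pyro) of one such hybrid gradient estimate. Everything in it is a made-up toy, not anything from Pyro: z is drawn as mu + eps with unit variance, y has three categories with logits W * z, and the likelihood of the observation given y is a fixed table LIK. It only covers the likelihood term (no prior or entropy term for z), and it estimates the gradient of the expected logmeanexp aggregate, which by Jensen's inequality is a lower bound on the expected log marginal likelihood rather than that quantity itself.

```python
import math
import random

# Hypothetical toy model (assumptions, not from the post):
#   z ~ Normal(mu, 1), reparametrized as z = mu + eps
#   y ~ Categorical(softmax(W * z)) over 3 categories
#   p(obs | y) is read from the fixed table LIK
W = [-1.0, 0.0, 1.0]
LIK = [0.1, 0.3, 0.9]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logmeanexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))


def grad_estimate(mu, num_y_samples, rng):
    """Single-draw estimate of d/dmu E[logmeanexp_k log p(obs | y_k)]."""
    z = mu + rng.gauss(0.0, 1.0)  # reparametrized sample, so dz/dmu = 1
    probs = softmax([w * z for w in W])
    ys = rng.choices(range(len(W)), weights=probs, k=num_y_samples)
    bound = logmeanexp([math.log(LIK[y]) for y in ys])
    # Score of each y_k with respect to z: d/dz log p(y_k | z) = w_y - E_p[w].
    mean_w = sum(p * w for p, w in zip(probs, W))
    score = sum(W[y] - mean_w for y in ys)
    # The likelihood depends on z only through y, so the pathwise term is zero;
    # the REINFORCE term is chained through dz/dmu = 1.
    return bound * score, bound


rng = random.Random(0)
estimates = [grad_estimate(0.5, 4, rng)[0] for _ in range(2000)]
print(sum(estimates) / len(estimates))  # Monte Carlo average of the gradient
```

In this toy the score term is written out by hand because the categorical logits are linear in z; with an autodiff framework the same thing would presumably be expressed as a surrogate loss instead. No baseline or other variance reduction is included.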

Is this what @eb8680_2’s TraceTMC_ELBO does? If not, is there another way to accomplish this? Is this a named algorithm in the literature? Are the resulting gradients unbiased? Are there caveats?

IIUC, that sounds similar to what the Storchastic paper describes, as illustrated in Figure 1 there.
