I am working on a problem using Pyro's SVI.
The model draws several latents that are independent of each other, and uses some deterministic computations to combine them into a final observation.
In the guide I condition some latents on others (so it is not mean-field):
def model(data, ...):
    latent1 = pyro.sample('latent1', dist1)
    ...
    latentn = pyro.sample('latentn', distn)
    # deterministic_computation combines latent1 ... latentn
    y_final_obs = pyro.sample('y_final_obs', dist.Normal(deterministic_computation, noise_level), obs=data)
def guide(data, ...):
    pyro.module('net1', net1)  # register net1's parameters with Pyro
    latent1_params = net1(data)
    latent1 = pyro.sample('latent1', dist1_guide(latent1_params))
    pyro.module('net2', net2)  # net2 conditions on the sampled latent1
    latent2_params = net2(data, latent1)
    latent2 = pyro.sample('latent2', dist2_guide(latent2_params))
    ...
    latentn = pyro.sample('latentn', distn_guide(latentn_params))
I have been using a normal distribution for latent1, and now want to incorporate a custom pre-trained deep Boltzmann generator (using normalizing flows), which is coded up in pure PyTorch. I'm still getting familiar with the Boltzmann generator codebase, but one way to think about it is as an invertible variational autoencoder: there are some deep layers, then sampling from a torch distribution (samples = torch.randn(...) or samples = torch.distributions.Cauchy(...).sample()), and then some more deep layers.
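To make that structure concrete, here is a minimal sketch in pure PyTorch. It is not the actual Boltzmann generator: the class name ToyGenerator, the layer sizes, and the choice to let the first stack of layers produce the Cauchy parameters are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class ToyGenerator(nn.Module):
    """Hypothetical stand-in: deep layers, a random draw, then more deep layers."""

    def __init__(self, in_dim=3, latent_dim=4, out_dim=2):
        super().__init__()
        # First stack of layers: assumed here to output the loc and log-scale of the draw.
        self.pre = nn.Sequential(
            nn.Linear(in_dim, 16), nn.Tanh(), nn.Linear(16, 2 * latent_dim)
        )
        # Second stack of layers: maps the random draw to the output space.
        self.post = nn.Sequential(
            nn.Linear(latent_dim, 16), nn.Tanh(), nn.Linear(16, out_dim)
        )

    def forward(self, x):
        loc, log_scale = self.pre(x).chunk(2, dim=-1)
        # The random draw in the middle of the network, in plain torch.
        samples = torch.distributions.Cauchy(loc, log_scale.exp()).sample()
        return self.post(samples)


out = ToyGenerator()(torch.zeros(5, 3))  # batch of 5 inputs -> 5 outputs of size 2
```

The line that draws samples is the one the question is about: it sits inside forward, not at the top level of any model function.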
I'm trying to engineer a good way to plug this into Pyro's SVI. I wonder if it's as simple as replacing samples = torch.distributions.Cauchy(...).sample() with samples = pyro.sample('latent', torch.distributions.Cauchy(...)). Or does the pyro.sample statement have to be exposed to the model in a certain way, rather than sitting further down the stack (inside wrappers)?
(In the guide, I'll have another distribution that maps to the Boltzmann generator's latent space, but that will use another net that takes in the data. That part seems straightforward to me.)
So the question is more about the model aspect (not the guide).
Hope that is clear!
We're deciding which is less work: rewriting our own SVI in pure PyTorch, or getting the Boltzmann generator to work with Pyro's SVI. I like Pyro's SVI functionality, such as num_particles
…