Find (most likely) latent variables for observations

Assume I have a model x = S z, where S and z are both estimated with SVI: S is a model parameter and z is my latent variable, which follows a normal distribution.
Now I would like to get the most likely z_i for a given observation x_i. How can I do that with Pyro's built-in functions?
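If the intent is one latent vector per observation, the generative model and the most likely z_i can be written out as below. This is standard Gaussian conditioning, not something Pyro computes for you. It assumes S and the noise covariance are held fixed. Here mu_0, Sigma_z, Sigma and S stand for z0_loc, z0_scale (used as a covariance), sigma0 and s in the code that follows. Because the posterior is Gaussian, its mode equals its mean.

```latex
z_i \sim \mathcal{N}(\mu_0, \Sigma_z), \qquad
x_i \mid z_i \sim \mathcal{N}(S z_i, \Sigma), \qquad i = 1, \dots, N

\hat{z}_i = \arg\max_{z} \, p(z \mid x_i)
          = \left( \Sigma_z^{-1} + S^\top \Sigma^{-1} S \right)^{-1}
            \left( S^\top \Sigma^{-1} x_i + \Sigma_z^{-1} \mu_0 \right)
```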

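import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

# G: dimension of each observation, K: dimension of the latent z (defined elsewhere)
def model(data):
    s = pyro.param("s", torch.randn(G, K))

    # z0_scale is passed as the covariance matrix of the prior on z0
    z0_loc   = pyro.param("z0_loc", torch.zeros(K))
    z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive_definite)
    z0       = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
    sigma0 = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive_definite)

    mean = s @ z0   # shape (G,): the same mean for every observation

    # data has shape (G, N), one observation per column (hence data.T),
    # so the plate size is the number of columns, not len(data) == G
    with pyro.plate("data", data.shape[1]):
        pyro.sample("obs", dist.MultivariateNormal(mean, sigma0), obs=data.T)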

guide = pyro.infer.autoguide.AutoNormal(model)

As formulated, your model has a single global latent variable z0, which is shared by all observations. There are no “local” latent variables private to each observation (i.e. there are no z_i). Based on your question, you probably need to reformulate your model first, e.g. by sampling z inside the data plate so that each observation gets its own z_i.

After training the guide, you can inspect the results with, e.g., its median() method (see "Automatic Guide Generation" in the Pyro documentation).

Thanks a lot!