Hi,
I’m trying to write a manual guide for a model. I have a 2D array of parameters that I’ve defined like this:
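```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(D, E):
    # Nested plates: "x1" (size D) is allocated the rightmost dim and
    # "x2" (size E) the next one, so phi has shape (E, D)
    with numpyro.plate("x1", D):
        with numpyro.plate("x2", E):
            numpyro.sample("phi", dist.Beta(5.0, 5.0))
```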
I’ve written a guide that uses independent Gaussians, pushed through a sigmoid, to parameterise the `phi` site in the model:
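```python
def guide_indep(D, E):
    with numpyro.plate("x1", D):
        with numpyro.plate("x2", E):
            # Scalar params: the plates expand the distribution to (E, D),
            # so every element of phi shares this single loc and scale
            loc = numpyro.param("loc", 0.0)
            scale = numpyro.param(
                "scale",
                0.1,
                constraint=dist.constraints.positive,
            )
            # A Normal pushed through a sigmoid keeps phi inside (0, 1),
            # matching the support of the Beta prior
            numpyro.sample(
                "phi",
                dist.TransformedDistribution(
                    dist.Normal(loc, scale),
                    dist.transforms.SigmoidTransform(),
                ),
            )
```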
This works.
I would like to write a guide that parameterises the `phi`s as a multivariate Gaussian, so that it can capture correlations between the `phi` parameters (i.e. a covariance matrix with non-zero off-diagonal elements). I can’t figure out how to get this to work: each draw from `dist.MultivariateNormal` is a whole vector (its event shape is `(D * E,)`), and that messes up the array shapes inside the plates. I would like to do something like this:
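```python
def guide_MVN(D, E):
    loc = numpyro.param("loc", jnp.zeros(D * E))
    # Parameterise the covariance through its lower-Cholesky factor: an
    # elementwise "positive" constraint does not guarantee a valid covariance.
    # The init matches a covariance of 0.1 on the diagonal, 0.01 elsewhere.
    cov_init = jnp.eye(D * E) * 0.09 + jnp.ones((D * E, D * E)) * 0.01
    scale_tril = numpyro.param(
        "scale_tril",
        jnp.linalg.cholesky(cov_init),
        constraint=dist.constraints.lower_cholesky,
    )
    # One correlated draw of all D * E values, squashed into (0, 1)
    phi_latent = numpyro.sample(
        "phi_latent",
        dist.TransformedDistribution(
            dist.MultivariateNormal(loc, scale_tril=scale_tril),
            dist.transforms.SigmoidTransform(),
        ),
    )
    with numpyro.plate("x1", D):
        with numpyro.plate("x2", E):
            # Reshape to the (E, D) batch shape the nested plates expect
            numpyro.sample("phi", dist.Delta(phi_latent.reshape(E, D)))
```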
But this doesn’t work, because the `phi_latent` site doesn’t exist in the model.
I found this thread which seems to be asking the same thing, but I can’t figure out how to make that solution work.
Would appreciate any help. Thanks