# How do I extract the variational posterior on untransformed parameters?

I am trying to implement variational inference for the parameters of a 1D Gaussian, based on this example from McElreath's book. It uses a model of the form

$$
\begin{aligned}
\text{data}_i &\sim \mathcal{N}(\mu, \sigma) \\
\mu &\sim \mathcal{N}(178, 20) \\
\sigma &\sim \mathcal{U}(0, 50)
\end{aligned}
$$


I chose a mean-field approximation

$$
q(\mu, \sigma) = \mathcal{N}(\mu \mid m, s)\, \mathrm{Ga}(\sigma \mid a, b)
$$


I can extract the posterior for the mean just fine, but for $\sigma$ it seems that NumPyro is quietly transforming the parameters in some way (presumably because of the $\mathcal{U}(0, 50)$ prior), since the posterior mean for $\sigma$ (namely $a/b$) does not make sense. However, if I compute the Predictive from this posterior, it all works fine.

So my question is: how can I figure out what transformation is being applied, and how can I compute the posterior mean of $\sigma$ in the original, untransformed space?

Code is below.

```python
import numpy as np
import pandas as pd
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.diagnostics import print_summary

# Load the Howell1 data and keep adults only
url = 'https://raw.githubusercontent.com/fehiepsi/rethinking-numpyro/master/data/Howell1.csv'
Howell1 = pd.read_csv(url, sep=';')
d = Howell1
d2 = d[d.age >= 18]
data = d2.height.values

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

def model(data):
    mu = numpyro.sample("mu", dist.Normal(178, 20))
    sigma = numpyro.sample("sigma", dist.Uniform(0, 50))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

# Mean-field guide: q(mu, sigma) = N(mu | m, s) Ga(sigma | a, b)
def guide(data):
    data_mean = jnp.mean(data)
    data_std = jnp.std(data)
    m = numpyro.param("m", data_mean)
    s = numpyro.param("s", 10.0, constraint=constraints.positive)
    a = numpyro.param("a", data_std, constraint=constraints.positive)
    b = numpyro.param("b", 1.0, constraint=constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(m, s))
    sigma = numpyro.sample("sigma", dist.Gamma(a, b))

optimizer = numpyro.optim.Momentum(step_size=0.001, mass=0.1)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
nsteps = 2000
svi_result = svi.run(rng_key_, nsteps, data=data)

print(svi_result.params)
a = np.array(svi_result.params['a'])
b = np.array(svi_result.params['b'])
m = np.array(svi_result.params['m'])
s = np.array(svi_result.params['s'])

print('empirical mean', jnp.mean(data))
print('empirical std', jnp.std(data))

print(r'posterior mean and std of $\mu$')
print([m, s])

print(r'posterior mean and std of $\sigma$')
post = dist.Beta(a, b)
print([post.mean, jnp.sqrt(post.variance)])

# Sample mu and sigma from the fitted guide
predictive = Predictive(guide, params=svi_result.params, num_samples=1000)
samples = predictive(rng_key, data)

print(samples['mu'].shape)
print(samples['sigma'].shape)
print_summary(samples, 0.95, False)
```


Hi @murphyk,
I'm not very familiar with how to extract the unscaled parameters from NumPyro's SVI results (@fehiepsi?), but you can take a look at the constraint registry, where it looks like interval constraints such as [0, 50] are transformed using a sigmoid function. If you have a distribution at hand, you can transform its samples via

```python
from numpyro.distributions import biject_to

constrained_samples = my_distribution.sample(...)
transform = biject_to(my_distribution.support)
unconstrained_samples = transform.inv(constrained_samples)
```


Hi @murphyk, you are right that the returned parameters are in the constrained space. I think we can add a keyword argument to svi.constrain_fn so that you can do

```python
svi_result = svi.run(...)
unconstrained_params = svi.constrain_fn(svi_result.params, invert=True)
```


What do you think? Currently, svi.constrain_fn performs unconstrained -> constrained.

Another solution is to expose optim_state to svi_result, so you can do

```python
svi_result = svi.run(...)
unconstrained_params = svi.optim.get_params(svi_result.optim_state)
```


Probably `post = dist.Beta(a,b)` is a typo: to match the guide, it should be `dist.Gamma(a, b)`.

I'm a bit confused here. Consider code snippet 4.28 from McElreath's book, as implemented by @fehiepsi. This computes a 2D Gaussian approximation to (mu, sigma).
Sampling from it returns values on the original scale:

```python
samples = guide.sample_posterior(random.PRNGKey(1), params, (nsamples,))
print_summary(samples, 0.95, False)
```

```
             mean       std    median      2.5%     97.5%     n_eff     r_hat
     mu    154.32      1.68    154.31    151.06    157.57   4327.88      1.00
  sigma      7.70      1.25      7.59      5.45     10.22   4173.88      1.00
```


But if I extract the parameters from the MVN, it seems that sigma is transformed:

```python
post = guide.get_posterior(params)
print(post.mean)
```

```
[154.335  -1.716]
```


Suppose I want to work with the MVN coming out of the Laplace approximation. I tried transforming it, to no avail:

```python
from numpyro.distributions import biject_to

unconstrained_samples = post.sample(rng_key, sample_shape=(nsamples,))
transform = biject_to(post.support)
constrained_samples = transform.inv(unconstrained_samples)

print(unconstrained_samples.shape)
print(jnp.mean(unconstrained_samples, axis=0))
print(jnp.mean(constrained_samples, axis=0))
```

```
(5000, 2)
[154.326  -1.724]
[154.326  -1.724]
```


Yes, you are right. Replacing it with Gamma(a, b) solves the problem: now the posterior mean is 9.1, which is on the right scale :). But I still have an issue with the Laplace approximation being unconstrained. See my code here:

I confirmed that the Laplace approximation transforms the sigma parameter $s$ (which has a uniform prior on [0, 50]) using $t = \operatorname{logit}(s/50)$ and then fits a Gaussian, so I can recover the posterior marginal of $s$ using $s = 50 \cdot \operatorname{sigmoid}(t)$ (see code below). But what is the Pyronic way to transform distribution objects (and/or samples), given that the code snippet above does not seem to work?

```python
import jax.numpy as jnp

def logit(p):
    return jnp.log(p / (1 - p))

def sigmoid(a):
    return 1 / (1 + jnp.exp(-a))

scale = 50
print(logit(7.7 / scale))
print(sigmoid(-1.7) * scale)
```

```
-1.7035668
7.723263
```
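The hand-rolled helpers can be swapped for JAX's built-ins, `jax.scipy.special.logit` and `jax.nn.sigmoid`, which should reproduce the numbers above. A minimal sketch, reusing the same illustrative inputs 7.7 and -1.7:

```python
import jax.numpy as jnp
from jax.nn import sigmoid
from jax.scipy.special import logit

scale = 50.0
t = logit(7.7 / scale)     # constrained sigma -> unconstrained value
s = sigmoid(-1.7) * scale  # unconstrained value -> constrained sigma
print(t, s)
```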


I see your point. Currently, AutoContinuous guides transform a model with constrained latent variables into a model with unconstrained variables (following ADVI), because the auto normal guides have real support rather than constrained support. For auto guides, calling guide.get_posterior() returns an approximate posterior of the transformed model (with unconstrained variables), while calling guide.sample_posterior() returns samples in the constrained space, i.e. samples for the original model. In your case, assuming that you want to work with the MVN posterior and collect unconstrained samples:

```python
unconstrained_samples = post.sample(rng_key, sample_shape=(nsamples,))
```


you can transform them back using a hidden method (>___<)

```python
guide._unpack_and_constrain(unconstrained_samples, svi_result.params)
```


It unpacks the flattened samples (MVN.sample returns an array rather than a dict mapping names to values) into per-site unconstrained samples, and then transforms those into constrained samples. If you think this method is useful, we will expose it.

That works, thanks