I am trying to implement variational inference for the parameters of a 1D Gaussian, based on this example from McElreath's book. It uses a model of the form
$$
\begin{aligned}
\text{data}_i &\sim \mathcal{N}(\mu, \sigma) \\
\mu &\sim \mathcal{N}(178, 20) \\
\sigma &\sim \mathcal{U}(0, 50)
\end{aligned}
$$
I chose a mean-field approximation:
$$
q(\mu, \sigma) = \mathcal{N}(\mu \mid m, s)\,\mathrm{Ga}(\sigma \mid a, b)
$$
I can extract the posterior for $\mu$ just fine, but for $\sigma$ it seems that NumPyro is quietly transforming the parameters in some way (presumably because of the $\mathcal{U}(0, 50)$ prior), since the posterior mean of $\sigma$ (namely $a/b$) does not make sense. However, if I draw samples with `Predictive` from this posterior, everything works fine.
So my question is: how can I figure out what transformation is being applied, and how can I compute the posterior mean of $\sigma$ in its original (untransformed) parameterization?
Code is below.
# Load the Howell1 table (semicolon-separated) and keep only rows with age >= 18.
url = 'https://raw.githubusercontent.com/fehiepsi/rethinking-numpyro/master/data/Howell1.csv'
Howell1 = pd.read_csv(url, sep=';')
d = Howell1
adult_mask = d["age"] >= 18
d2 = d.loc[adult_mask]
data = d2["height"].to_numpy()  # the observations fed to model/guide

# Derive two independent keys from seed 0: `rng_key_` drives SVI, `rng_key`
# is kept for sampling from the fitted guide later in the script.
rng_key, rng_key_ = random.split(random.PRNGKey(0))
def model(data):
    """Gaussian likelihood with a Normal prior on its mean and a Uniform prior on its scale.

    Priors: mu ~ Normal(178, 20), sigma ~ Uniform(0, 50).
    Likelihood: each element of ``data`` ~ Normal(mu, sigma), recorded at site "obs".

    Args:
        data: observed values (adult heights in this script).
    """
    location = numpyro.sample("mu", dist.Normal(178, 20))
    spread = numpyro.sample("sigma", dist.Uniform(0, 50))
    likelihood = dist.Normal(location, spread)
    numpyro.sample("obs", likelihood, obs=data)
def guide(data):
    """Mean-field variational family q(mu, sigma) = Normal(mu | m, s) * Gamma(sigma | a, b).

    Variational parameters (site names in quotes):
        "m": location of q(mu), initialised at the sample mean of ``data``.
        "s": scale of q(mu), initialised at 10, constrained positive.
        "a": first Gamma argument for q(sigma), initialised at the sample std.
        "b": second Gamma argument for q(sigma), initialised at 1.

    Args:
        data: the same observations passed to ``model``.
    """
    sample_mean, sample_std = jnp.mean(data), jnp.std(data)

    loc_mu = numpyro.param("m", sample_mean)
    # `constraints.positive` keeps these parameters > 0 during optimisation.
    scale_mu = numpyro.param("s", 10, constraint=constraints.positive)
    shape_sigma = numpyro.param("a", sample_std, constraint=constraints.positive)
    rate_sigma = numpyro.param("b", 1, constraint=constraints.positive)

    # Sample sites must match the latent sites of `model` ("mu", "sigma").
    numpyro.sample("mu", dist.Normal(loc_mu, scale_mu))
    # NOTE(review): Gamma has support (0, inf) but the model's prior on sigma is
    # Uniform(0, 50); q can place (usually negligible) mass outside the prior's
    # support. A bounded family would be stricter but would change the param names.
    numpyro.sample("sigma", dist.Gamma(shape_sigma, rate_sigma))
# Fit the guide to the model by maximising the ELBO with SVI.
optimizer = numpyro.optim.Momentum(step_size=0.001, mass=0.1)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
nsteps = 2000
svi_result = svi.run(rng_key_, nsteps, data=data)
print(svi_result.params)

# Per NumPyro's SVI docs, svi_result.params holds the parameters in constrained
# space, i.e. exactly the values handed to the distributions in `guide`; no
# inverse transform is needed before using them.
a = np.array(svi_result.params['a'])
b = np.array(svi_result.params['b'])
m = np.array(svi_result.params['m'])
s = np.array(svi_result.params['s'])

print('empirical mean', jnp.mean(data))
print('empirical std', jnp.std(data))

print(r'posterior mean and std of $\mu$')
print([m, s])

print(r'posterior mean and std of $\sigma$')
# The guide models sigma with dist.Gamma(a, b) (concentration a, rate b), so the
# summary must use the same family: mean a/b, std sqrt(a)/b. Summarising with
# dist.Beta(a, b) instead gives mean a/(a+b) on (0, 1), which is meaningless here.
post = dist.Gamma(a, b)
print([post.mean, jnp.sqrt(post.variance)])

# Cross-check the closed-form summary against samples drawn from the fitted guide.
predictive = Predictive(guide, params=svi_result.params, num_samples=1000)
samples = predictive(rng_key, data)
print(samples['mu'].shape)
print(samples['sigma'].shape)
print_summary(samples, 0.95, False)