Sampling Normal Distribution

Hi all,

I want to sample from a normal distribution with a mean (loc) of 10 in my model definition. However, while debugging, I found that sampling with numpyro.sample gives different results from instantiating the distribution and sampling from it directly. That is, running

numpyro.sample('mu', dist.Normal(10,0.1), sample_shape=(5,))

gives the output DeviceArray([-0.0852356 , 1.6111875 , 0.04703999, 0.42531633, -0.08525848], dtype=float32), whereas running,

test = dist.Normal(10,0.1)
test.sample(random.PRNGKey(0), sample_shape=(5,))

gives DeviceArray([10.0187845, 9.871666 , 9.972891 , 10.124906 , 10.024447 ], dtype=float32).

Is this the expected behaviour? Why do the values returned by numpyro.sample not seem to come from the defined distribution?

Thank you in advance!

This seems strange. Could you provide some reproducible code?

Apologies for the late reply. This code reproduces the issue (although it doesn’t do anything useful):

import numpyro
from numpyro.infer import MCMC, HMC
import numpyro.distributions as dist
from jax import random

def model(data):
    mu = numpyro.sample('mu', dist.Normal(10,0.1), sample_shape=(data['n'],))
    print("mu: " + str(mu))
    test = dist.Normal(10,0.1)
    tmp = test.sample(random.PRNGKey(0), sample_shape=(data['n'],))
    print("test sample: " + str(tmp))

data = {}
data['n'] = 5

kernel = HMC(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=1000, num_chains=1, progress_bar=True)
mcmc.run(random.PRNGKey(0),data)

When I run this, the output is:

mu: [ 0.3011222 0.7353368 1.944623 -1.2138557 1.8295264]
test sample: [10.0187845 9.871666 9.972891 10.124906 10.024447 ]

Using the following versions:
numpyro 0.6.0
jax 0.2.10
jaxlib 0.1.62

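I see, I thought you were sampling from the Normal distribution directly. Under MCMC, the values printed inside the model are the sampler's initial values, and those depend on the initialization strategy. HMC uses init_to_uniform by default, which draws initial values uniformly from (-2, 2) in the unconstrained space; that is why all the printed values of mu fall in that range instead of near 10. You can try setting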

HMC(model, init_strategy=numpyro.infer.init_to_sample)

to get initial values from your prior.

Ah, thank you very much, that works as expected now 🙂