Sampling from a posterior with a weighted log-likelihood function

Hello, I’m trying to get to grips with NumPyro. I’ve implemented a basic Bayesian neural network regression model, and I can run NUTS to get samples from the posterior over the neural network weights.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def numpyro_model(Z, M, sd, sd_prior):
    # Extract X and y (the last column of Z holds the targets)
    X = Z[:, :-1]
    y = Z[:, -1].reshape(-1, 1)

    # Get the input and output dimensions, as well as the number of data points
    N, D = X.shape
    P = y.shape[1]

    # Sample input layer weights (Normal prior)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D, M)), sd_prior * jnp.ones((D, M))))
    assert w1.shape == (D, M)

    # Sample input layer biases (Normal prior)
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros(M), sd_prior * jnp.ones(M)))
    assert b1.shape == (M,)

    # Compute first layer activations
    z1 = jnp.tanh(jnp.matmul(X, w1) + b1)
    assert z1.shape == (N, M)

    # Sample output layer weights (Normal prior)
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((M, P)), sd_prior * jnp.ones((M, P))))
    assert w2.shape == (M, P)

    # Sample output layer bias (Normal prior)
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros(P), sd_prior * jnp.ones(P)))
    assert b2.shape == (P,)

    # Compute output
    z2 = jnp.matmul(z1, w2) + b2  # <= output of the neural network
    assert z2.shape == (N, P)

    if y is not None:
        assert z2.shape == y.shape

    # Observe the data: one independent P-dimensional Normal per data point
    with numpyro.plate("data", N):
        numpyro.sample("y", dist.Normal(z2, sd).to_event(1), obs=y)
```

```python
import time

from numpyro.infer import MCMC, NUTS


def sample_posterior(model, num_samples, num_warmup, num_chains, rng_key, Z, M, sd, sd_prior):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    )
    # Extra positional arguments after rng_key are forwarded to the model
    mcmc.run(rng_key, Z, M, sd, sd_prior)
    print("\nMCMC elapsed time:", time.time() - start)

    mcmc_results = mcmc.get_samples()
    return mcmc_results
```

Now, I want to replace the log-likelihood function with a weighted version. That is, each data point's log-likelihood is multiplied by some positive weight. Then I want to be able to pass a weight vector to the sampler and model, and sample in the same way as above. I’d love some guidance on how to go about this!
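Written out for the model above, the unnormalised log-target the sampler would see is roughly the following. The notation is introduced here for illustration: $f_\theta(x_i)$ stands for row $i$ of the network output `z2`, $w_i$ for the (hypothetical) positive weight of data point $i$, and the priors are left unweighted. With all $w_i = 1$ it reduces to the ordinary log-posterior.

```latex
\log \tilde{p}(\theta \mid X, y)
  = \log p(\theta)
  + \sum_{i=1}^{N} w_i \,\log \mathcal{N}\!\left(y_i \,\middle|\, f_\theta(x_i),\ \mathrm{sd}^2 I_P\right)
  + \text{const},
  \qquad w_i > 0
```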

My current understanding is that I will need to write a distribution class that inherits from `Normal` but replaces the `log_prob` method with a weighted version. Then I would replace `numpyro.sample("y", dist.Normal(z2, sd).to_event(1), obs = y)` with `numpyro.sample("y", dist.WeightedNormal(z2, sd, wts).to_event(1), obs = y)`. Is this all that is required?

Thanks in advance!

You just need to use `scale` (the `numpyro.handlers.scale` effect handler); see here:


Amazing, this is exactly what I was looking for. Thank you!