I have written the following numpyro model which models a gaussian process and then performs binomial regression. The issue i am having is that I get the same values for the parameters for every single sample (almost as if something is wrong with the random key thats been passed). Any idea why this may be ?
print(mcmc.get_samples())
{'b0': Array([1.9732423, 1.9732423, 1.9732423, 1.9732423, 1.9732423, 1.9732423,
1.9732423, 1.9732423, 1.9732423, 1.9732423], dtype=float32),
'f': Array([[-1.8119389, -1.4418964, 0.5465225, ..., -1.9745315, 1.8767447,
-1.0638568],
[-1.8119389, -1.4418964, 0.5465225, ..., -1.9745315, 1.8767447,
-1.0638568],
[-1.8119389, -1.4418964, 0.5465225, ..., -1.9745315, 1.8767447,
-1.0638568],
...,
[2.0822305e-03, 9.9998713e-01, 4.5200054e-02, 2.7854379e-14,
5.2529750e-03, 1.0000000e+00, 6.3772571e-10, 9.9999976e-01,
7.6609455e-02, 9.9984181e-01, 7.4886680e-01, 9.9729437e-01,
4.9855632e-01, 9.6825278e-01, 9.9998069e-01, 2.7288701e-02,
5.2043241e-01, 9.9281102e-01, 4.9969818e-02, 2.4912173e-01,
9.9999297e-01, 9.9981600e-01, 9.9959320e-01, 9.9998963e-01,
2.8445052e-03, 9.8520803e-01, 9.9631631e-01, 7.5159292e-03,
9.8712730e-01, 9.9997675e-01, 9.8993134e-01, 2.1812723e-04,
7.1244460e-01, 9.9976546e-01, 4.7712356e-02, 9.9909592e-01,
8.7383834e-10, 8.1159353e-01, 9.2745048e-01, 7.1263969e-08,
9.9783307e-01, 9.9707043e-01, 9.6637094e-01, 6.5089977e-01,
9.1920555e-02, 6.7142791e-01, 2.5963061e-04, 9.3706262e-01,
5.6199029e-02, 9.8684663e-01, 2.4630034e-03, 2.2057141e-01,
2.7173051e-01, 6.1513454e-01, 4.1309121e-04, 8.0895321e-03,
9.9967158e-01, 6.6662908e-01],
[2.0822305e-03, 9.9998713e-01, 4.5200054e-02, 2.7854379e-14,
5.2529750e-03, 1.0000000e+00, 6.3772571e-10, 9.9999976e-01,
7.6609455e-02, 9.9984181e-01, 7.4886680e-01, 9.9729437e-01,
4.9855632e-01, 9.6825278e-01, 9.9998069e-01, 2.7288701e-02,
5.2043241e-01, 9.9281102e-01, 4.9969818e-02, 2.4912173e-01,
9.9999297e-01, 9.9981600e-01, 9.9959320e-01, 9.9998963e-01,
2.8445052e-03, 9.8520803e-01, 9.9631631e-01, 7.5159292e-03,
9.8712730e-01, 9.9997675e-01, 9.8993134e-01, 2.1812723e-04,
7.1244460e-01, 9.9976546e-01, 4.7712356e-02, 9.9909592e-01,
8.7383834e-10, 8.1159353e-01, 9.2745048e-01, 7.1263969e-08,
9.9783307e-01, 9.9707043e-01, 9.6637094e-01, 6.5089977e-01,
9.1920555e-02, 6.7142791e-01, 2.5963061e-04, 9.3706262e-01,
5.6199029e-02, 9.8684663e-01, 2.4630034e-03, 2.2057141e-01,
2.7173051e-01, 6.1513454e-01, 4.1309121e-04, 8.0895321e-03,
9.9967158e-01, 6.6662908e-01]], dtype=float32)}
Note how the same values have been produced for all samples.
The Model
def prev_model_gp_aggr_v2(
n_specimens_lo:Float[Array, "n_regions"],
n_specimens_hi:Float[Array, "n_regions"],
x:Float[Array, "n_grids_pts 2"],
gp_kernel,
noise:Float,
jitter:Float,
M_lo:Float[Array, "n_regions n_grid_pts"],
M_hi:Float[Array, "n_regions n_grid_pts"],
tested_positive:Union[None] = None
):
# Sample values from distribution for hyperparams
kernel_length = numpyro.sample("kernel_length", dist.InverseGamma(3,3))
kernel_var = numpyro.sample("kernel_var", dist.HalfNormal(0.05))
# Compute Kernel
k = gp_kernel(x,x,kernel_var,kernel_length, noise, jitter)
# Smaple positive cases
f = numpyro.sample(
"f",
dist.MultivariateNormal(loc = jnp.zeros(x.shape[0]), covariance_matrix = k),
obs = None
)
# Aggregate for all low/high points
gp_aggr_lo:Float[Array, "n_regions n_samples"] = numpyro.deterministic("gp_aggr_lo", M_g(M_lo,f)) #e.g. (9,) <- note this is the shape for one sample
gp_aggr_hi:Float[Array, "n_regions n_samples"] = numpyro.deterministic("gp_aggr_hi", M_g(M_hi,f)) #e.g. (49,)
# Though we only want to use low level to predict high level, during training
# we will incorporate both
gp_aggr:Float[Array, "n_region_lo+n_region_hi n_samples"] = numpyro.deterministic("gp_aggr", jnp.concatenate([gp_aggr_lo,gp_aggr_hi])) #eg. (58,)
# Fixed effects
b0:Float = numpyro.sample("b0", dist.Normal(0,1))
# linear predictor = Fixed effects (b0.X) + random effects (GP)
lp:Float[Array, "n_regions_lo+hi n_samples"] = b0 + gp_aggr
#Theta represents the prevalence values (our target/y)
theta:Float[Array, "n_regions_lo+hi n_samples"] = numpyro.deterministic("theta", jax.nn.sigmoid(lp))
# We need to make tested positive cases array of shape : (n_regions_lo+hi)
# where hi values are NaN's. This is because we are planning to show the model
# only low region tested cases to predict for the hi cases.
tested_positive:Float[Array, "n_regions_lo+hi"] = jnp.pad(tested_positive, (0,M_hi.shape[0]), constant_values = 0.0) #e.g. (58,)
tested_positive = jnp.where(tested_positive == 0, jnp.nan, tested_positive)
tested_positive_mask = ~jnp.isnan(tested_positive) # e.g [True x 9 ... False x 49]
# We show all tested case values, which is n in our Binomial regression model below
tested_cases:Float[Array, "n_regions_lo+hi"] = jnp.concatenate([n_specimens_lo, n_specimens_hi], axis = 0) #e.g (58,)
# We need to use numpyro.handlers.mask to make sure we can account for NaN values in observations
with numpyro.handlers.mask(mask = tested_positive_mask):
n_positive_obs = numpyro.sample(
"n_positive_obs",
dist.BinomialLogits(total_count=tested_cases, logits = lp),
obs = tested_positive
)
return mcmc
How MCMC is run
key,subkey = random.split(random.PRNGKey(0),2)
#? You need to change these two
n_warm = 10
n_samples = 10
mcmc = MCMC(
NUTS(prev_model_gp_aggr_v2),
num_warmup = n_warm,
num_samples = n_samples
)
# Run MCMC
start = time.time()
mcmc.run(
key,
n_specimens_lo = n_specimens_lo,
n_specimens_hi = n_specimens_hi,
x = x,
gp_kernel = exp_sq_kernel,
noise = 1e-4,
jitter = 1e-4,
M_lo = pol_pt_lo,
M_hi = pol_pt_hi,
tested_positive = n_tested_pos_lo
)
end = time.time()
t_elapsed_min = round((end-start)/60)
print(f"Time takens for aggGP : {t_elapsed_min} minutes")
Any clue what I am doing wrong ?