NumPyro producing the exact same parameter values for every sample

I have written the following NumPyro model, which models a Gaussian process and then performs binomial regression. The issue I am having is that I get the same values for the parameters for every single sample (almost as if something is wrong with the random key that's being passed). Any idea why this might be?

print(mcmc.get_samples())
{'b0': Array([1.9732423, 1.9732423, 1.9732423, 1.9732423, 1.9732423, 1.9732423,
        1.9732423, 1.9732423, 1.9732423, 1.9732423], dtype=float32),
 'f': Array([[-1.8119389, -1.4418964,  0.5465225, ..., -1.9745315,  1.8767447,
         -1.0638568],
        [-1.8119389, -1.4418964,  0.5465225, ..., -1.9745315,  1.8767447,
         -1.0638568],
        ...,
        [-1.8119389, -1.4418964,  0.5465225, ..., -1.9745315,  1.8767447,
         -1.0638568]], dtype=float32),
 'theta': Array([[2.0822305e-03, 9.9998713e-01, 4.5200054e-02, ..., 8.0895321e-03,
         9.9967158e-01, 6.6662908e-01],
        [2.0822305e-03, 9.9998713e-01, 4.5200054e-02, ..., 8.0895321e-03,
         9.9967158e-01, 6.6662908e-01],
        ...,
        [2.0822305e-03, 9.9998713e-01, 4.5200054e-02, ..., 8.0895321e-03,
         9.9967158e-01, 6.6662908e-01]], dtype=float32)}

Note how the same values have been produced for all samples.

The Model

def prev_model_gp_aggr_v2(
    n_specimens_lo: Float[Array, "n_regions"],
    n_specimens_hi: Float[Array, "n_regions"],
    x: Float[Array, "n_grid_pts 2"],
    gp_kernel,
    noise: Float,
    jitter: Float,
    M_lo: Float[Array, "n_regions n_grid_pts"],
    M_hi: Float[Array, "n_regions n_grid_pts"],
    tested_positive: Union[Float[Array, "n_regions"], None] = None
    ):

    # Sample values from distribution for hyperparams
    kernel_length = numpyro.sample("kernel_length", dist.InverseGamma(3,3))
    kernel_var = numpyro.sample("kernel_var", dist.HalfNormal(0.05))

    # Compute kernel
    k = gp_kernel(x, x, kernel_var, kernel_length, noise, jitter)
    # Sample the latent GP function values
    f = numpyro.sample(
        "f",
        dist.MultivariateNormal(loc = jnp.zeros(x.shape[0]), covariance_matrix = k),
    )

    # Aggregate for all low/high points
    gp_aggr_lo:Float[Array, "n_regions n_samples"] = numpyro.deterministic("gp_aggr_lo", M_g(M_lo,f)) #e.g. (9,) <- note this is the shape for one sample
    gp_aggr_hi:Float[Array, "n_regions n_samples"] = numpyro.deterministic("gp_aggr_hi", M_g(M_hi,f)) #e.g. (49,)

    # Though we only want to use low level to predict high level, during training
    # we will incorporate both
    gp_aggr:Float[Array, "n_region_lo+n_region_hi n_samples"] = numpyro.deterministic("gp_aggr", jnp.concatenate([gp_aggr_lo,gp_aggr_hi])) #eg. (58,)

    # Fixed effects
    b0:Float = numpyro.sample("b0", dist.Normal(0,1))
    # linear predictor = Fixed effects (b0.X) + random effects (GP)
    lp:Float[Array, "n_regions_lo+hi n_samples"] = b0 + gp_aggr 

    #Theta represents the prevalence values (our target/y)
    theta:Float[Array, "n_regions_lo+hi n_samples"] = numpyro.deterministic("theta", jax.nn.sigmoid(lp))

    # We need to make the tested-positive array of shape (n_regions_lo+hi),
    # where the hi values are NaNs. This is because we plan to show the model
    # only the low-region tested cases and predict for the hi regions.
    tested_positive:Float[Array, "n_regions_lo+hi"] = jnp.pad(tested_positive, (0, M_hi.shape[0]), constant_values = 0.0) #e.g. (58,)
    # NB: this also turns any genuine zero counts in the low regions into NaN
    tested_positive = jnp.where(tested_positive == 0, jnp.nan, tested_positive)
    tested_positive_mask = ~jnp.isnan(tested_positive) # e.g. [True x 9 ... False x 49]
    
    # We show all tested case values, which is n in our Binomial regression model below
    tested_cases:Float[Array, "n_regions_lo+hi"] = jnp.concatenate([n_specimens_lo, n_specimens_hi], axis = 0) #e.g (58,)

    # We need to use numpyro.handlers.mask to make sure we can account for NaN values in observations
    with numpyro.handlers.mask(mask = tested_positive_mask):
        n_positive_obs = numpyro.sample(
            "n_positive_obs",
            dist.BinomialLogits(total_count=tested_cases, logits = lp),
            obs = tested_positive
        )
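For the masked likelihood above, one pattern worth noting: a common recipe for partially observed sites in NumPyro is to impute a valid dummy value for the masked-out entries and let the mask zero out their contribution, so the log-probability and its gradient stay finite everywhere. A minimal sketch, reusing the names from the model above (safe_obs is a hypothetical helper name):

# Sketch: keep obs finite everywhere; the mask removes the imputed entries
# from the likelihood, so the dummy zeros never influence the posterior.
safe_obs = jnp.where(tested_positive_mask, tested_positive, 0.0)
with numpyro.handlers.mask(mask = tested_positive_mask):
    n_positive_obs = numpyro.sample(
        "n_positive_obs",
        dist.BinomialLogits(total_count = tested_cases, logits = lp),
        obs = safe_obs
    )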

How MCMC is run

key, subkey = random.split(random.PRNGKey(0), 2)
# You need to change these two
n_warm = 10
n_samples = 10
mcmc = MCMC(
    NUTS(prev_model_gp_aggr_v2),
    num_warmup = n_warm,
    num_samples = n_samples
)

# Run MCMC
start = time.time()
mcmc.run(
    key,
    n_specimens_lo = n_specimens_lo,
    n_specimens_hi = n_specimens_hi,
    x = x,
    gp_kernel = exp_sq_kernel,
    noise = 1e-4,
    jitter = 1e-4,
    M_lo = pol_pt_lo,
    M_hi = pol_pt_hi,
    tested_positive = n_tested_pos_lo
)
end = time.time()
t_elapsed_min = round((end - start) / 60)
print(f"Time taken for aggGP: {t_elapsed_min} minutes")
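A chain that never moves usually means every proposal is being rejected, for example because the potential energy or its gradient is NaN. One quick check is to collect divergence flags via extra_fields and look at the summary statistics. A self-contained toy sketch (toy_model and mcmc_check are hypothetical stand-ins for the real model and sampler):

import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def toy_model():
    # hypothetical stand-in for prev_model_gp_aggr_v2
    numpyro.sample("mu", dist.Normal(0.0, 1.0))

mcmc_check = MCMC(NUTS(toy_model), num_warmup = 100, num_samples = 100)
mcmc_check.run(random.PRNGKey(0), extra_fields = ("diverging",))
print(mcmc_check.get_extra_fields()["diverging"].sum())  # ~0 for a healthy chain
mcmc_check.print_summary()  # degenerate n_eff / r_hat would flag a stuck chain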

Any clue what I am doing wrong?

I guess you can replace the NaNs with something else. I think the gradient does not propagate correctly to the parameters.
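
For what it's worth, this is the classic jnp.where pitfall in JAX: even when NaN entries are masked out in the forward pass, the gradient flows through both branches, and 0 * nan = nan poisons it. A self-contained toy illustration (not necessarily what numpyro's mask handler does internally):

import jax
import jax.numpy as jnp

def masked_loglik(x):
    obs = jnp.array([1.0, jnp.nan])            # second observation is missing
    logp = -(x - obs) ** 2                     # NaN log-prob for the missing entry
    return jnp.where(~jnp.isnan(obs), logp, 0.0).sum()

print(jax.grad(masked_loglik)(0.5))            # nan -- the masked branch still leaks

The usual fix is to also sanitise the observation itself before it enters the log-prob (the "double where" trick), which is effectively what replacing the NaNs with a dummy value does.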


Thanks, I tried it, but that didn't seem to be the case.
I kept the missing values as zeros and it still returned constant values:

tested_positive:Float[Array, "n_regions_lo+hi"] = jnp.pad(tested_positive, (0,M_hi.shape[0]), constant_values = 0.0) #e.g. (58,)

n_positive_obs = numpyro.sample(
    "n_positive_obs",
    dist.BinomialLogits(total_count=tested_cases, logits = lp),
    obs = tested_positive
)

You can try replacing the numpyro.sample statements with function arguments and using jax.grad to check why the gradient is 0:

def model(var1, var2, etc.):
   ...
   return output.sum()

jax.grad(model)(var1, var2)
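
Another way to run this check without rewriting the model is to differentiate the joint log density directly via numpyro.infer.util.log_density. A self-contained toy sketch (toy_model is a hypothetical stand-in; the real model, its kwargs, and a params dict would go in the same slots):

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import log_density

def toy_model(y = None):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs = y)

def joint(params, y):
    # log_density evaluates the model's joint log-probability at params
    logp, _ = log_density(toy_model, (), {"y": y}, params)
    return logp

y = jnp.array([0.3, jnp.nan])  # one missing observation
print(jax.grad(joint)({"mu": jnp.array(0.0)}, y))  # {'mu': nan}

A NaN (or zero) gradient here would explain a chain that never moves.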

I think the issue arises from the GP, because when I removed it just to test, I got different values for the hyperparameters:

def prev_model(
    n_tests_lo,
    n_tests_hi,
    n_positive_lo,
    inference = False
    ):
    # Fixed effects 
    b0 = numpyro.sample("b0", dist.Normal(0, 1)) #(,)
    # Linear predictor 
    lp = b0 #(,)
    theta = numpyro.deterministic("theta", jax.nn.sigmoid(lp)) #(,)
    # This is the "n" binomial distribution  
    n_tests = jnp.concatenate([n_tests_lo, n_tests_hi]) #(58,)
    if not inference:
        #n_positive_hi = jnp.repeat(-100, n_tests_hi.shape[0]) #(49,)
        n_positive_hi = jnp.repeat(jnp.nan, n_tests_hi.shape[0]) #(49,)
        n_positive = jnp.concatenate([n_positive_lo, n_positive_hi], axis = 0) #(58,)
        #mask_regions = jnp.where(n_positive == -100, False, True)
        mask_regions = ~jnp.isnan(n_positive)

        with numpyro.handlers.mask(mask = mask_regions):
            n_positive_obs = numpyro.sample(
                "n_positive_obs",
                dist.BinomialLogits(total_count = n_tests, logits = lp),
                obs = n_positive
            )
    else:
        n_positive_obs = numpyro.sample(
            "n_positive_obs",
            dist.BinomialLogits(total_count = n_tests, logits = lp),
            obs = None 
        )
print(posterior_samples)
>>>
{'b0': Array([-2.105111 , -2.1085582, -2.1254044, -2.129344 , -2.1185746,
        -2.1192055, -2.1189632, -2.1181834, -2.129106 , -2.1122074],      dtype=float32),
 'theta': Array([0.10860106, 0.10826778, 0.10665207, 0.10627729, 0.10730453,
        0.10724411, 0.10726731, 0.10734201, 0.10629988, 0.10791597],      dtype=float32)}

Just for completeness: I managed to solve the issue. It actually wasn't an error on the NumPyro model side but rather in the data. The data was too dispersed (some very high and some very low values across regions), and replacing the Binomial distribution with a BetaBinomial worked. After this change, the parameter values were different for different samples.
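
For anyone landing here later, the swap might look roughly like this inside the model above (a sketch, not the exact code used; phi is a hypothetical overdispersion parameter, and the mean/concentration parameterisation is one common choice):

# Hypothetical sketch: BetaBinomial in place of BinomialLogits.
phi = numpyro.sample("phi", dist.HalfNormal(10.0))  # hypothetical overdispersion prior
p = jax.nn.sigmoid(lp)
with numpyro.handlers.mask(mask = tested_positive_mask):
    n_positive_obs = numpyro.sample(
        "n_positive_obs",
        dist.BetaBinomial(
            concentration1 = p * phi,          # alpha = p * phi
            concentration0 = (1.0 - p) * phi,  # beta  = (1 - p) * phi
            total_count = tested_cases
        ),
        obs = tested_positive
    )

As phi grows, the BetaBinomial approaches the plain Binomial, so this only adds flexibility for overdispersed counts.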