Numpyro producing the exact same value for parameters for every sample

I have written the following numpyro model which models a gaussian process and then performs binomial regression. The issue i am having is that I get the same values for the parameters for every single sample (almost as if something is wrong with the random key thats been passed). Any idea why this may be ?

print(mcmc.get_samples())
{'b0': Array([1.9732423, 1.9732423, 1.9732423, 1.9732423, 1.9732423, 1.9732423,
        1.9732423, 1.9732423, 1.9732423, 1.9732423], dtype=float32),
 'f': Array([[-1.8119389, -1.4418964,  0.5465225, ..., -1.9745315,  1.8767447,
         -1.0638568],
        [-1.8119389, -1.4418964,  0.5465225, ..., -1.9745315,  1.8767447,
         -1.0638568],
        [-1.8119389, -1.4418964,  0.5465225, ..., -1.9745315,  1.8767447,
         -1.0638568],
        ...,
        [2.0822305e-03, 9.9998713e-01, 4.5200054e-02, 2.7854379e-14,
         5.2529750e-03, 1.0000000e+00, 6.3772571e-10, 9.9999976e-01,
         7.6609455e-02, 9.9984181e-01, 7.4886680e-01, 9.9729437e-01,
         4.9855632e-01, 9.6825278e-01, 9.9998069e-01, 2.7288701e-02,
         5.2043241e-01, 9.9281102e-01, 4.9969818e-02, 2.4912173e-01,
         9.9999297e-01, 9.9981600e-01, 9.9959320e-01, 9.9998963e-01,
         2.8445052e-03, 9.8520803e-01, 9.9631631e-01, 7.5159292e-03,
         9.8712730e-01, 9.9997675e-01, 9.8993134e-01, 2.1812723e-04,
         7.1244460e-01, 9.9976546e-01, 4.7712356e-02, 9.9909592e-01,
         8.7383834e-10, 8.1159353e-01, 9.2745048e-01, 7.1263969e-08,
         9.9783307e-01, 9.9707043e-01, 9.6637094e-01, 6.5089977e-01,
         9.1920555e-02, 6.7142791e-01, 2.5963061e-04, 9.3706262e-01,
         5.6199029e-02, 9.8684663e-01, 2.4630034e-03, 2.2057141e-01,
         2.7173051e-01, 6.1513454e-01, 4.1309121e-04, 8.0895321e-03,
         9.9967158e-01, 6.6662908e-01],
        [2.0822305e-03, 9.9998713e-01, 4.5200054e-02, 2.7854379e-14,
         5.2529750e-03, 1.0000000e+00, 6.3772571e-10, 9.9999976e-01,
         7.6609455e-02, 9.9984181e-01, 7.4886680e-01, 9.9729437e-01,
         4.9855632e-01, 9.6825278e-01, 9.9998069e-01, 2.7288701e-02,
         5.2043241e-01, 9.9281102e-01, 4.9969818e-02, 2.4912173e-01,
         9.9999297e-01, 9.9981600e-01, 9.9959320e-01, 9.9998963e-01,
         2.8445052e-03, 9.8520803e-01, 9.9631631e-01, 7.5159292e-03,
         9.8712730e-01, 9.9997675e-01, 9.8993134e-01, 2.1812723e-04,
         7.1244460e-01, 9.9976546e-01, 4.7712356e-02, 9.9909592e-01,
         8.7383834e-10, 8.1159353e-01, 9.2745048e-01, 7.1263969e-08,
         9.9783307e-01, 9.9707043e-01, 9.6637094e-01, 6.5089977e-01,
         9.1920555e-02, 6.7142791e-01, 2.5963061e-04, 9.3706262e-01,
         5.6199029e-02, 9.8684663e-01, 2.4630034e-03, 2.2057141e-01,
         2.7173051e-01, 6.1513454e-01, 4.1309121e-04, 8.0895321e-03,
         9.9967158e-01, 6.6662908e-01]], dtype=float32)}

Note how the same values have been produced for all samples.

The Model

def prev_model_gp_aggr_v2(
    n_specimens_lo:Float[Array, "n_regions"],
    n_specimens_hi:Float[Array, "n_regions"],
    x:Float[Array, "n_grids_pts 2"],
    gp_kernel,
    noise:Float,
    jitter:Float,
    M_lo:Float[Array, "n_regions n_grid_pts"],
    M_hi:Float[Array, "n_regions n_grid_pts"],
    tested_positive:Union[None] = None
    ):

    # Sample values from distribution for hyperparams
    kernel_length = numpyro.sample("kernel_length", dist.InverseGamma(3,3))
    kernel_var = numpyro.sample("kernel_var", dist.HalfNormal(0.05))

    # Compute Kernel
    k = gp_kernel(x,x,kernel_var,kernel_length, noise, jitter)
    # Smaple positive cases
    f = numpyro.sample(
        "f",
        dist.MultivariateNormal(loc = jnp.zeros(x.shape[0]), covariance_matrix = k),
        obs = None
    )

    # Aggregate for all low/high points
    gp_aggr_lo:Float[Array, "n_regions n_samples"] = numpyro.deterministic("gp_aggr_lo", M_g(M_lo,f)) #e.g. (9,) <- note this is the shape for one sample
    gp_aggr_hi:Float[Array, "n_regions n_samples"] = numpyro.deterministic("gp_aggr_hi", M_g(M_hi,f)) #e.g. (49,)

    # Though we only want to use low level to predict high level, during training
    # we will incorporate both
    gp_aggr:Float[Array, "n_region_lo+n_region_hi n_samples"] = numpyro.deterministic("gp_aggr", jnp.concatenate([gp_aggr_lo,gp_aggr_hi])) #eg. (58,)

    # Fixed effects
    b0:Float = numpyro.sample("b0", dist.Normal(0,1))
    # linear predictor = Fixed effects (b0.X) + random effects (GP)
    lp:Float[Array, "n_regions_lo+hi n_samples"] = b0 + gp_aggr 

    #Theta represents the prevalence values (our target/y)
    theta:Float[Array, "n_regions_lo+hi n_samples"] = numpyro.deterministic("theta", jax.nn.sigmoid(lp))

    # We need to make tested positive cases array of shape : (n_regions_lo+hi) 
    # where hi values are NaN's. This is because we are planning to show the model
    # only low region tested cases to predict for the hi cases.
    tested_positive:Float[Array, "n_regions_lo+hi"] = jnp.pad(tested_positive, (0,M_hi.shape[0]), constant_values = 0.0) #e.g. (58,)
    tested_positive = jnp.where(tested_positive == 0, jnp.nan, tested_positive)
    tested_positive_mask = ~jnp.isnan(tested_positive) # e.g [True x 9 ... False x 49]
    
    # We show all tested case values, which is n in our Binomial regression model below
    tested_cases:Float[Array, "n_regions_lo+hi"] = jnp.concatenate([n_specimens_lo, n_specimens_hi], axis = 0) #e.g (58,)

    # We need to use numpyro.handlers.mask to make sure we can account for NaN values in observations
    with numpyro.handlers.mask(mask = tested_positive_mask):
        n_positive_obs = numpyro.sample(
            "n_positive_obs",
            dist.BinomialLogits(total_count=tested_cases, logits = lp),
            obs = tested_positive
        )
    return mcmc

How MCMC is run

 key,subkey = random.split(random.PRNGKey(0),2)
    #? You need to change these two
    n_warm = 10
    n_samples = 10
    mcmc = MCMC(
        NUTS(prev_model_gp_aggr_v2),
        num_warmup = n_warm,
        num_samples = n_samples
    )

    # Run MCMC
    start = time.time()
    mcmc.run(
        key,
        n_specimens_lo = n_specimens_lo,
        n_specimens_hi = n_specimens_hi,
        x = x,
        gp_kernel = exp_sq_kernel,
        noise = 1e-4,
        jitter = 1e-4,
        M_lo = pol_pt_lo,
        M_hi = pol_pt_hi,
        tested_positive = n_tested_pos_lo
    )
    end = time.time()
    t_elapsed_min = round((end-start)/60)
    print(f"Time takens for aggGP : {t_elapsed_min} minutes")

Any clue what I am doing wrong ?

I guess you can replace nan by something else. I think grad does not propagate correctly to the parameters.

1 Like

Thanks I tried it, but didnt seem like the case.
I just kept the missing values as zeros and it still returned constant values

tested_positive:Float[Array, "n_regions_lo+hi"] = jnp.pad(tested_positive, (0,M_hi.shape[0]), constant_values = 0.0) #e.g. (58,)

n_positive_obs = numpyro.sample(
    "n_positive_obs",
    dist.BinomialLogits(total_count=tested_cases, logits = lp),
    obs = tested_positive
)

You can try to replace numpyro.sample statements by some arguments and use jax.grad to check why grad is 0

def model(var1, var2, etc.):
   ...
   return output.sum()

jax.grad(model)(var1, var2)
1 Like

I think the issue arises from the GP, cause when I removed it just to test, I get different values for the hypeparameters

def prev_model(
    n_tests_lo,
    n_tests_hi,
    n_positive_lo,
    inference = False
    ):
    # Fixed effects 
    b0 = numpyro.sample("b0", dist.Normal(0, 1)) #(,)
    # Linear predictor 
    lp = b0 #(,)
    theta = numpyro.deterministic("theta", jax.nn.sigmoid(lp)) #(,)
    # This is the "n" binomial distribution  
    n_tests = jnp.concatenate([n_tests_lo, n_tests_hi]) #(58,)
    if not inference:
        #n_positive_hi = jnp.repeat(-100, n_tests_hi.shape[0]) #(49,)
        n_positive_hi = jnp.repeat(jnp.nan, n_tests_hi.shape[0]) #(49,)
        n_positive = jnp.concatenate([n_positive_lo, n_positive_hi], axis = 0) #(58,)
        #mask_regions = jnp.where(n_positive == -100, False, True)
        mask_regions = ~jnp.isnan(n_positive)

        with numpyro.handlers.mask(mask = mask_regions):
            n_positive_obs = numpyro.sample(
            "n_positive_obs",
            dist.BinomialLogits(total_count = n_tests, logits = lp),
            obs = n_positive
        )
    else:
        n_positive_obs = numpyro.sample(
            "n_positive_obs",
            dist.BinomialLogits(total_count = n_tests, logits = lp),
            obs = None 
        )
print(posterior_samples)
>>>
{'b0': Array([-2.105111 , -2.1085582, -2.1254044, -2.129344 , -2.1185746,
        -2.1192055, -2.1189632, -2.1181834, -2.129106 , -2.1122074],      dtype=float32),
 'theta': Array([0.10860106, 0.10826778, 0.10665207, 0.10627729, 0.10730453,
        0.10724411, 0.10726731, 0.10734201, 0.10629988, 0.10791597],      dtype=float32)}