From PyMC to NumPyro: NUTS results with much higher std

Hi all,

I have a model that I first implemented in NumPyro and then in PyMC.
I ran inference with NUTS sampling for both implementations, and the NumPyro run gives much higher standard deviations for the theta parameters. The model itself is very simple. I have also included the inference summaries for both. I suspect an issue with the shapes or with jnp.einsum.
For both implementations I used 2000 draws with 1000 tuning steps and target_accept = 0.99. The mean posterior predictions for the observations are quite similar in both.

I would appreciate any help.

C = 4
L = 8
x.shape = (2400,32)
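A quick way to rule out a shape problem in the einsum is to compare it against an explicit matrix-vector product. The sketch below uses NumPy's np.einsum, which follows the same subscript rules as jnp.einsum, on random stand-in data with the sizes from the post (the real x is not shown, so the values here are hypothetical):

```python
import numpy as np

# Sizes from the post: C = 4, L = 8, x.shape = (2400, 32).
C, L, N = 4, 8, 2400

rng = np.random.default_rng(0)
x = rng.normal(size=(N, L * C))       # stand-in for the real x
theta_lc = rng.normal(size=L * C)     # one weight per column of x

# 'i,ji->j' multiplies theta_lc[i] by x[j, i] and sums over i: one value per row j.
via_einsum = np.einsum('i,ji->j', theta_lc, x)
via_matmul = x @ theta_lc

print(via_einsum.shape)                     # (2400,)
print(np.allclose(via_einsum, via_matmul))  # True
```

If the real arrays give the same result, the einsum is simply x @ theta_lc and is unlikely to be the source of the difference.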

NumPyro model

from typing import Optional

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def additive_model(L: int,
                   C: int,
                   x: jax.Array,
                   y: Optional[jax.Array] = None) -> jax.Array:
    # Additive GP map: phi = theta_0 + dot(x, theta_lc)
    ## GP map parameters
    theta_0 = numpyro.sample('theta_0', dist.Normal(loc=0.0, scale=20.0))
    theta_lc = numpyro.sample('theta_lc',
                              dist.Normal(loc=jnp.zeros(L * C),
                                          scale=20.0 * jnp.ones(L * C)))
    # 'i,ji->j' is a matrix-vector product: (N, L*C) @ (L*C,) -> (N,)
    latent_val = theta_0 + jnp.einsum('i,ji->j', theta_lc, x)
    phi = numpyro.deterministic('phi', latent_val)
    # Pre-defined measurement process: log(1 + e^phi), via the overflow-safe softplus
    g = jax.nn.softplus(phi)
    # Noise model
    sigma = numpyro.sample('sigma', dist.HalfNormal(scale=1.0))

    return numpyro.sample('yhat', dist.Normal(loc=g, scale=sigma), obs=y)
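Both listings compute the measurement process as log(1 + e^phi) directly, which overflows once e^phi leaves the floating-point range (phi above roughly 709 in float64, and above roughly 88 in float32, which JAX uses by default). A minimal NumPy comparison of the naive form with the log-sum-exp form np.logaddexp(0, phi); in JAX the counterparts are jax.nn.softplus and jnp.logaddexp:

```python
import numpy as np


def softplus_naive(phi):
    # Direct transcription of log(1 + e^phi); np.exp overflows to inf for large phi.
    return np.log(1 + np.exp(phi))


def softplus_stable(phi):
    # log(e^0 + e^phi) via log-sum-exp, which never exponentiates a large number.
    return np.logaddexp(0.0, phi)


phi = np.array([-50.0, 0.0, 50.0, 1000.0])
with np.errstate(over='ignore'):
    naive = softplus_naive(phi)
stable = softplus_stable(phi)

print(naive)   # last entry is inf
print(stable)  # last entry is 1000.0
```

The naive form also rounds softplus(-50) to exactly 0.0, because 1 + e^-50 is 1.0 in float64, while the stable form keeps the tiny positive value.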

PyMC model

import pymc as pm

additive_model = pm.Model()
with additive_model:
    # Additive GP map: phi = theta_0 + dot(x, theta_lc)
    ## GP map parameters
    theta_0 = pm.Normal('theta_0', mu=0, sigma=20)
    theta_lc = pm.Normal('theta_lc', mu=0, sigma=20, shape=L * C)
    ## Latent phenotype
    # np.einsum expects concrete arrays, not symbolic PyTensor variables, so use
    # a symbolic matrix-vector product instead: (N, L*C) @ (L*C,) -> (N,)
    latent_val = theta_0 + pm.math.dot(x, theta_lc)
    phi = pm.Deterministic('phi', latent_val)
    # Pre-defined measurement process: log(1 + e^phi), overflow-safe
    g = pm.math.log1pexp(phi)
    # Noise model
    sigma = pm.HalfNormal('sigma', sigma=1)
    Y_hat = pm.Normal('yhat', mu=g, sigma=sigma, observed=y)
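Because the two listings are meant to define the same joint density, one way to localize a discrepancy is a plain-NumPy reference log density, independent of both libraries, that can be evaluated at any fixed parameter point. The sketch below writes out the priors and likelihood from the model code above on the constrained scale (sigma > 0); the function names are hypothetical, not part of either library. Comparing it with each library's own log density (NumPyro has numpyro.infer.util.log_density, PyMC models have compile_logp) needs some care, since both libraries sample sigma on a log-transformed scale and may add a Jacobian term.

```python
import numpy as np


def normal_logpdf(v, mu, sigma):
    # Elementwise log N(v | mu, sigma).
    return -0.5 * np.log(2.0 * np.pi) - np.log(sigma) - 0.5 * ((v - mu) / sigma) ** 2


def half_normal_logpdf(v, sigma):
    # log HalfNormal(v | sigma) for v >= 0, i.e. the log of 2 * N(v | 0, sigma).
    return np.log(2.0) + normal_logpdf(v, 0.0, sigma)


def reference_log_joint(theta_0, theta_lc, sigma, x, y):
    """Hypothetical helper: log p(theta_0, theta_lc, sigma, y | x) of the additive model."""
    lp = normal_logpdf(theta_0, 0.0, 20.0)
    lp += normal_logpdf(theta_lc, 0.0, 20.0).sum()
    lp += half_normal_logpdf(sigma, 1.0)
    phi = theta_0 + x @ theta_lc          # latent phenotype, one value per row of x
    g = np.logaddexp(0.0, phi)            # measurement process log(1 + e^phi)
    lp += normal_logpdf(y, g, sigma).sum()
    return lp


# Usage on tiny synthetic data (not the data from the post).
rng = np.random.default_rng(0)
x_small = rng.normal(size=(5, 3))
y_small = rng.normal(size=5)
print(reference_log_joint(0.1, np.zeros(3), 0.5, x_small, y_small))
```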

NumPyro parameters

PyMC parameters
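One thing that may be worth checking, assuming x is a one-hot encoding of length-L sequences over C characters (C = 4, L = 8 and x.shape = (2400, 32) = (N, L*C) are consistent with that, but the post does not say how x is built): in that case the C columns of each position sum to 1, which duplicates the intercept. The data then constrain only 1 + L*(C-1) = 25 of the 33 directions in (theta_0, theta_lc), and along the other 8 only the Normal(0, 20) priors act, so individual theta components can have wide posteriors while phi and the predictions stay well determined. The sketch below checks this on hypothetical one-hot data:

```python
import numpy as np

C, L, N = 4, 8, 2400
rng = np.random.default_rng(0)

# Hypothetical x: random length-L sequences over C characters, one-hot encoded
# so that column l*C + c is 1 when position l holds character c.
seqs = rng.integers(0, C, size=(N, L))
x = np.zeros((N, L * C))
rows = np.repeat(np.arange(N), L)
cols = (np.arange(L) * C + seqs).ravel()
x[rows, cols] = 1.0

# Columns seen by phi = theta_0 + x @ theta_lc: an intercept plus the 32 features.
design = np.column_stack([np.ones(N), x])
rank = np.linalg.matrix_rank(design)
print(design.shape[1], rank)  # 33 parameters, but only 1 + L*(C-1) independent directions

# Move theta_0 up by d and every weight at position 0 down by d: phi is unchanged.
theta_0, theta_lc = 0.5, rng.normal(size=L * C)
d = 10.0
shifted = theta_lc.copy()
shifted[:C] -= d
phi_a = theta_0 + x @ theta_lc
phi_b = (theta_0 + d) + x @ shifted
print(np.allclose(phi_a, phi_b))
```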

Hello @mahdik,

Two suggestions:

  • Before making any assumptions about whether the two software packages are returning different results, I’d run a lot more samples, e.g. 5k warm-up and 30k post-warm-up. You can’t make reliable comparisons based on a handful of samples.
  • target_accept = 0.99 is generally a really bad idea (it’ll lead to small steps), so I’d recommend leaving this and the other HMC/NUTS hyperparameters at their default values unless you’re an expert and understand how those algorithms work in detail.
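To illustrate the first suggestion: a posterior standard deviation estimated from correlated draws is itself noisy, and that noise shrinks only as the number of draws grows. The sketch below uses synthetic AR(1) chains as a stand-in for autocorrelated NUTS draws (the autocorrelation 0.95 is a hypothetical value, not taken from either run) and measures how much the estimated std varies across 50 independent repeats for 2,000 versus 30,000 draws:

```python
import numpy as np


def ar1_chains(n_chains, n_draws, rho, rng):
    """Independent stationary AR(1) chains with unit marginal variance."""
    z = np.empty((n_chains, n_draws))
    z[:, 0] = rng.normal(size=n_chains)
    # Innovations scaled so that every z[:, t] keeps variance 1.
    noise = np.sqrt(1.0 - rho**2) * rng.normal(size=(n_chains, n_draws))
    for t in range(1, n_draws):
        z[:, t] = rho * z[:, t - 1] + noise[:, t]
    return z


rng = np.random.default_rng(1)
rho = 0.95  # hypothetical lag-1 autocorrelation; the true std is 1.0

spread = {}
for n_draws in (2_000, 30_000):
    std_estimates = ar1_chains(50, n_draws, rho, rng).std(axis=1, ddof=1)
    # Spread of the std estimate across repeats, i.e. its Monte Carlo error.
    spread[n_draws] = std_estimates.std()
    print(n_draws, round(spread[n_draws], 4))
```

The spread for 30,000 draws should come out several times smaller. For real runs, the effective-sample-size columns of the summaries (n_eff in NumPyro's, ess_bulk in ArviZ's) indicate how many independent draws the chains are actually worth.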

Hello Martin (@martinjankowiak), thanks for the suggestions. I will try them out.