Dot_general error when uncentering a model

Hello,

The following model works:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def BQR_SSVS(tau=0.5, a=5, b=.04, beta_1=0.5, beta_2=0.5, gamma_1=3, gamma_2=100, X=None, y=None):
    
    T, K = X.shape
    
    # Deterministic
    theta            = (1-2*tau)/(tau*(1-tau))
    tau_star_squared = 2/(tau*(1-tau))
    c                = 10**(-5)
    
    # Non-Beta Priors
    sigma = numpyro.sample('sigma', dist.InverseGamma(a,1/b))
    
    # Beta Priors
    beta0  = numpyro.sample('beta0', dist.Normal(0, 1))
    pi_0  = numpyro.sample('pi_0', dist.Beta(beta_1,beta_2))
    with numpyro.plate("plate_beta", K):
        gamma       = numpyro.sample('gamma', dist.Bernoulli(pi_0))
        delta       = numpyro.sample('delta', dist.InverseGamma(gamma_1,gamma_2))
        sigma_beta  = jnp.sqrt((1-gamma)*c*delta + gamma*delta)
        beta        = numpyro.sample("beta", dist.Normal(0, sigma_beta))
           
    y_mean = beta0+jnp.matmul(X,beta)
    with numpyro.plate("plate_T",T):
        z  = numpyro.sample('z', dist.Exponential(1/sigma))
        sigma_obs = jnp.sqrt(tau_star_squared*sigma*z)
        y  = numpyro.sample("y", dist.Normal(y_mean+theta*z, sigma_obs),obs=y)
```

but when changing

```python
beta        = numpyro.sample("beta", dist.Normal(0, sigma_beta))
```

to

```python
unscaled_betas = numpyro.sample("unscaled_betas", dist.Normal(0.0, 1.0))
beta = numpyro.deterministic("betas", sigma_beta * unscaled_betas)
```

It throws the following error:


It seems that changing this line changes the dimensions of beta, but why?

You can trace back the error to see the shapes of sigma_beta and unscaled_betas. Enumeration is probably happening here, and your code cannot handle a batch of gamma values: when gamma is enumerated, it gets an extra batch dimension, which propagates to sigma_beta and therefore to beta. Because your model is not suitable for enumeration (the site y, which is outside of plate_beta, depends on the site gamma), you can try inference algorithms like DiscreteHMCGibbs. :slight_smile:
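To see why only the non-centered version breaks, here is a minimal shape sketch in plain jax.numpy. It assumes, for illustration only, that enumeration stacks both values of gamma along a new leading axis of shape (2, 1); the exact layout NumPyro uses depends on the model's plates. In the centered version, beta is a sample site whose value keeps shape (K,) and only its distribution picks up the extra dimension. In the non-centered version, beta is computed from sigma_beta and inherits that dimension, so the matmul with X fails inside dot_general.

```python
import jax.numpy as jnp

T, K = 5, 3
X = jnp.ones((T, K))
c = 1e-5
delta = jnp.ones(K)
unscaled_betas = jnp.ones(K)

# Without enumeration: gamma has shape (K,), so beta has shape (K,)
gamma = jnp.array([0.0, 1.0, 1.0])
sigma_beta = jnp.sqrt((1 - gamma) * c * delta + gamma * delta)
beta = sigma_beta * unscaled_betas
y_mean = jnp.matmul(X, beta)  # (T, K) @ (K,) -> (T,)

# Assumed enumerated layout: both values of gamma stacked along a new
# leading axis, shape (2, 1); it broadcasts into sigma_beta and beta
gamma_enum = jnp.array([0.0, 1.0]).reshape(2, 1)
sigma_beta_enum = jnp.sqrt((1 - gamma_enum) * c * delta + gamma_enum * delta)
beta_enum = sigma_beta_enum * unscaled_betas  # shape (2, K)

raised = False
try:
    # matmul treats beta_enum as a (2, K) matrix: contracting dims K and 2 differ
    jnp.matmul(X, beta_enum)
except (TypeError, ValueError) as err:
    raised = True
    print("matmul failed:", err)

# A batch-aware alternative: contract X with the last axis of beta only
y_mean_enum = jnp.einsum("tk,...k->...t", X, beta_enum)  # shape (2, T)
```

The einsum line only shows how the shapes could be made to broadcast. The model itself is still not suitable for enumeration, since y depends on gamma outside plate_beta, so switching the sampler is the more robust fix.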

Thanks @fehiepsi. May I check my understanding? Is this your suggestion:

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

def BQR_SSVS(tau=0.5, a=5, b=.04, beta_1=0.5, beta_2=0.5, gamma_1=3, gamma_2=100, X=None, y=None):
    
    T, K = X.shape
    
    # Deterministic
    theta            = (1-2*tau)/(tau*(1-tau))
    tau_star_squared = 2/(tau*(1-tau))
    c                = 10**(-5)
    
    # Non-Beta Priors
    sigma = numpyro.sample('sigma', dist.InverseGamma(a,1/b))
    z     = numpyro.sample('z', dist.Exponential((1/sigma)*jnp.ones(T)))
    
    # Beta Priors
    beta0          = numpyro.sample('beta0', dist.Normal(0, 1))
    pi_0           = numpyro.sample('pi_0', dist.Beta(beta_1,beta_2))
    gamma          = numpyro.sample('gamma', dist.Bernoulli(pi_0*jnp.ones(K)))
    delta          = numpyro.sample('delta', dist.InverseGamma(gamma_1*jnp.ones(K),gamma_2*jnp.ones(K)))
    sigma_beta     = jnp.sqrt((1-gamma)*c*delta + gamma*delta)
    unscaled_betas = numpyro.sample("unscaled_beta", dist.Normal(0.0, jnp.ones(K)))
    beta           = numpyro.deterministic("beta", sigma_beta * unscaled_betas)
    
    # Likelihood
    mean_function = beta0+jnp.matmul(X,beta)+theta*z
    sigma_obs     = jnp.sqrt(tau_star_squared*sigma*z)
    y             = numpyro.sample("y", dist.Normal(mean_function, sigma_obs),obs=y)

# trees, target_acc_prob, n_warmup, n_samples, X and y are defined elsewhere
m_SSVS = DiscreteHMCGibbs(NUTS(BQR_SSVS,
                               max_tree_depth=trees,
                               target_accept_prob=target_acc_prob))
mcmc = MCMC(m_SSVS, num_warmup=n_warmup, num_samples=n_samples, num_chains=1)

mcmc.run(random.PRNGKey(0), X=X, y=y)
mcmc.print_summary(0.89, exclude_deterministic=False)
```

Yes. Just one comment: it is clearer to use plates in your model, i.e. to use your old (plated) model with DiscreteHMCGibbs.

Thanks! I tried it on both versions :slight_smile: On both, I ended up getting NaNs in the summary for some of the gammas. Is this expected behavior? Did I forget an option on DiscreteHMCGibbs, or do my model/priors need tweaking?

I think this is expected behavior. Those gamma values stay constant because the potential energy after switching them to the other value is large, so the switch is very unlikely to be accepted. When a variable is constant across all samples, its effective sample size and Rhat are NaN.

Hmm, interesting. Makes sense, though. Thanks again!

Internally, DiscreteHMCGibbs just picks one Bernoulli variable at random, flips its value, and performs an MH correction to accept or reject that proposal. The technique in this paper can be more efficient: it suggests a way to choose which Bernoulli variable to update in the next step. However, implementing it is a bit tricky. Hopefully, DiscreteHMCGibbs is already good enough for your problem.
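To make that description concrete, here is a minimal single-site flip Metropolis-Hastings sketch in JAX over a vector of binary variables. It is not NumPyro's implementation: there are no continuous variables or HMC steps, and the toy target `log_density` (independent Bernoullis with fixed logits) is hypothetical. Picking the site uniformly and flipping it is a symmetric proposal, so the acceptance ratio reduces to the density ratio. A site with a large logit is rarely flipped away from its preferred value, which is the same mechanism behind the constant gammas discussed above.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy target: independent Bernoulli sites with fixed logits
LOGITS = jnp.array([4.0, -4.0, 0.5])

def log_density(gamma):
    # Bernoulli log-prob: x * logit - log(1 + exp(logit)), summed over sites
    return jnp.sum(gamma * LOGITS - jnp.logaddexp(0.0, LOGITS))

def flip_one_site_mh(key, gamma):
    key_site, key_accept = jax.random.split(key)
    # Choose one Bernoulli site uniformly at random and propose flipping it
    k = jax.random.randint(key_site, (), 0, gamma.shape[0])
    proposal = gamma.at[k].set(1.0 - gamma[k])
    # Symmetric proposal, so the MH log acceptance ratio is the density ratio
    log_accept = log_density(proposal) - log_density(gamma)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_accept
    return jnp.where(accept, proposal, gamma)

def run(key, gamma0, num_steps):
    def step(gamma, step_key):
        new_gamma = flip_one_site_mh(step_key, gamma)
        return new_gamma, new_gamma
    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, gamma0, keys)
    return samples

samples = run(jax.random.PRNGKey(0), jnp.zeros(3), 5000)
print(samples.mean(axis=0))  # should approach sigmoid(LOGITS)
```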