Hi, I am currently working on a factor analysis model for single-cell data. Here's my model — it's a simple z @ w model, where w is under a horseshoe prior. I noticed that once I restrict z to a row-stochastic matrix (rows sum to 1) through a Dirichlet prior or a logistic-normal prior, my regularized horseshoe goes crazy: tau goes to a very large number and lambda_tilde approaches c. I have no idea why the model would prefer such a solution. (I don't think it's because of geometry or anything; I tried reparameterization and MCMC as well, and tau still goes to large numbers.) Does anyone have a clue?
def model(X, n_topics=20):
    """Poisson factor model: X ~ Poisson(ls * softmax(z @ w, axis=1)).

    z (cells x topics) is row-stochastic via a sparse Dirichlet prior;
    w (topics x genes) carries a regularized-horseshoe prior
    (Piironen & Vehtari, 2017) with slab constant c.

    Args:
        X: (n_cells, n_genes) observed count matrix.
        n_topics: number of latent topics (default 20, preserving the
            previously hard-coded value).
    """
    n_cells, n_genes = X.shape
    # Per-cell library size, used as the Poisson exposure.
    ls = X.sum(axis=1, keepdims=True)

    # Slab constant of the regularized horseshoe (plays the role of c^2
    # in the Piironen & Vehtari parameterization given the formula below).
    # c = numpyro.sample("c", dist.InverseGamma(0.5, 0.5))
    c = 1

    with numpyro.plate("n_topics", n_topics):
        w_shape = (n_topics, n_genes)
        # Global (per-topic) and local (per-entry) horseshoe scales.
        tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones((n_topics, 1))).to_event(1))
        lambda_ = numpyro.sample("lambda_", dist.HalfCauchy(jnp.ones(w_shape)).to_event(1))
        # Regularized horseshoe variance: ~ tau^2 * lambda^2 when that
        # product is small, saturating at c when it is large.
        lambda_tilde = (c * tau**2 * lambda_**2) / (c + tau**2 * lambda_**2)
        w = numpyro.sample("w", dist.Normal(jnp.zeros(w_shape), jnp.sqrt(lambda_tilde)).to_event(1))

    with numpyro.plate("n_cells", n_cells):
        # Sparse Dirichlet makes each cell's topic loadings row-stochastic.
        # FIX: cell count was hard-coded as 2695; derive it from X so the
        # model works for any dataset size.
        z = numpyro.sample('z', dist.Dirichlet(jnp.ones((n_cells, n_topics)) * 0.1))
        # NOTE(review): because each row of z sums to 1, adding a scalar a
        # to every entry of w shifts every logit of z @ w by exactly a, and
        # the per-cell softmax is invariant to that shift — so the overall
        # scale/offset of w is not identified, and the horseshoe scales can
        # drift to absorb arbitrarily large offsets.
        mean = jax.nn.softmax(z @ w, axis=1)
        numpyro.sample("X", dist.Poisson(ls * mean + 1e-8).to_event(1), obs=X)
# Variational guide: AutoNormal fits an independent (mean-field) Normal to
# each latent site of `model`. NOTE(review): constrained sites such as the
# Dirichlet-distributed z are presumably handled in unconstrained space via
# numpyro's automatic transforms — verify against the numpyro version used.
guide = AutoNormal(model)