Hi, I'm currently working on a factor analysis model for single-cell data. Here's my model: it's a simple z @ w model, where w is under a regularized horseshoe prior. I noticed that once I restrict z to a row-stochastic matrix (rows sum to 1), either through a Dirichlet prior or a logistic normal, my regularized horseshoe goes crazy: tau goes to very large values and lambda_tilde approaches c. I have no idea why the model would prefer such a solution. I don't think it's caused by the posterior geometry: I tried reparameterization and also MCMC, and tau still goes to large values. Does anyone have a clue?
```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoNormal

def model(X):
    n_cells, n_genes = X.shape
    ls = X.sum(axis=1, keepdims=True)  # library size per cell
    # Slab scale c of the regularized horseshoe (fixed here; could instead get a prior)
    # c = numpyro.sample("c", dist.InverseGamma(0.5, 0.5))
    c = 1
    with numpyro.plate("n_topics", 20):
        w_shape = (20, n_genes)
        tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones((20, 1))).to_event(1))  # global scale per topic
        lambda_ = numpyro.sample("lambda_", dist.HalfCauchy(jnp.ones(w_shape)).to_event(1))  # local scales
        # Regularized variance: about tau^2 * lambda_^2 when small, tends to c when large
        lambda_tilde = (c * tau**2 * lambda_**2) / (c + tau**2 * lambda_**2)
        w = numpyro.sample("w", dist.Normal(jnp.zeros(w_shape), jnp.sqrt(lambda_tilde)).to_event(1))
    with numpyro.plate("n_cells", n_cells):
        z_shape = (n_cells, 20)
        # Logistic-normal alternative:
        # z = numpyro.sample("z", dist.Normal(jnp.zeros(z_shape), jnp.ones(z_shape)).to_event(1))
        # z = jax.nn.softmax(z, axis=1)
        z = numpyro.sample("z", dist.Dirichlet(jnp.ones(z_shape) * 0.1))  # rows sum to 1
        # Non-negative but not row-stochastic alternative:
        # z = numpyro.sample("z", dist.HalfNormal(jnp.ones(z_shape) * 10).to_event(1))
        mean = jax.nn.softmax(z @ w, axis=1)
        numpyro.sample("X", dist.Poisson(ls * mean + 1e-8).to_event(1), obs=X)

guide = AutoNormal(model)
```
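One property of the row-stochastic constraint that may be worth keeping in mind (not a claim about the cause): when every row of z is non-negative and sums to 1, every row of z @ w is a convex combination of the rows of w, so each entry of z @ w stays between the column-wise min and max of w. The NumPy sketch below checks this for both ways of building z used in the model; the sizes and random draws are made up for illustration and are not the real data.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_topics, n_genes = 100, 20, 50  # illustrative sizes only

w = rng.normal(size=(n_topics, n_genes))

# Row-stochastic z, mirroring the two options in the model
z_dirichlet = rng.dirichlet(np.full(n_topics, 0.1), size=n_cells)
logits = rng.normal(size=(n_cells, n_topics))
z_logistic_normal = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

for z in (z_dirichlet, z_logistic_normal):
    assert np.allclose(z.sum(axis=1), 1.0)
    eta = z @ w
    # Each row of eta is a convex combination of the rows of w,
    # so every entry stays inside the column-wise range of w
    assert np.all(eta >= w.min(axis=0) - 1e-12)
    assert np.all(eta <= w.max(axis=0) + 1e-12)

# An unconstrained z (here Normal with scale 10) has no such bound
z_free = rng.normal(scale=10.0, size=(n_cells, n_topics))
print(np.abs(z_free @ w).max(), np.abs(w).max())
```

In other words, with a simplex z the magnitude of the logits passed to the softmax is limited by the magnitude of w itself, whereas an unconstrained z can scale them up.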
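The observation that lambda_tilde approaches c as tau grows follows directly from the lambda_tilde expression in the model. Since w is drawn as Normal(0, sqrt(lambda_tilde)), lambda_tilde is a variance and c acts as the slab variance. A small NumPy check of the two limits, using the same formula with c = 1 and arbitrary illustrative values of lambda_ (the helper name `lambda_tilde` here is hypothetical):

```python
import numpy as np

def lambda_tilde(tau, lam, c=1.0):
    """Regularized-horseshoe variance, same expression as in the model above."""
    s = tau**2 * lam**2
    return c * s / (c + s)

lam = np.array([0.1, 1.0, 10.0])
c = 1.0

# Small tau: lambda_tilde is about tau^2 * lambda^2 (ordinary horseshoe shrinkage)
small = lambda_tilde(1e-3, lam, c)
print(small / (1e-3**2 * lam**2))  # ratios close to 1

# Large tau: lambda_tilde tends to c for every lambda, i.e. w ~ Normal(0, sqrt(c))
large = lambda_tilde(1e3, lam, c)
print(large)  # all close to c
```

So a very large tau means essentially every entry of w is drawn from the slab with variance c, i.e. almost no shrinkage is applied.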