Hi, I’m currently attempting to use numpyro to do some processing on the output of another tool which makes probability based predictions so it is treated like a (deterministic) Markov process with perturbations that I try to extract some meaning from. The issue is that whether I use SVI or DiscreteHMCGibbs for instance, it doesn’t pick up the switch_t (which is the latent variable of interest). It’s probably something simple, code below. I suspect the discrete latent variable is just marginalised out but then SVI doesn’t seem to actively learn any of the parameters either. Thanks in advance for any pointers.
def FLARE_HMM(sequences, n_populations,device):
num_sequences, max_length, data_dim = sequences.shape
key = random.PRNGKey(0)
#parameters for the beta distribution for the number of switches
concentration0 = np.param("conc0", init_value=0.002,constraint=dist.constraints.positive)
concentration1 = np.param("conc1", init_value=1,constraint=dist.constraints.positive)
sigma = np.param("sigma", random.beta(key=key,a=1,b=10))
with np.handlers.seed(rng_seed=23):
haplotype_transition = jnp.transpose(np.param("transition",uniform_init(n_populations,n_populations),constraint=dist.constraints.simplex),(1,0))
#turn transition into a double transition matrix
haplotype_transition1 = jnp.concatenate([haplotype_transition,jnp.zeros((n_populations,n_populations))],1)
haplotype_transition2 = jnp.concatenate([jnp.zeros((n_populations, n_populations)),haplotype_transition], 1)
final_transition = jnp.concatenate([haplotype_transition1,haplotype_transition2],0)
sw_fact = np.param("switch_factor1", lambda rng_key: dist.Beta(concentration0,concentration1).sample(rng_key,(1,)), rng_key=key)
def transition(carry, y):
key = random.PRNGKey(0)
m_curr, t = carry
with np.plate("sequences1", num_sequences,dim=-3):
switch_t = np.sample("switch",dist.Bernoulli(probs=sw_fact), rng_key=key, infer= {"enumerate": "parallel"})
with np.plate("sequences2", num_sequences, dim=-2):
switch_matrix_t = switch_t * jnp.repeat(jnp.expand_dims(jnp.eye(2 * n_populations),0),switch_t.shape[0],axis=0) + (1 - switch_t) * jnp.repeat(jnp.expand_dims(
jnp.diag(jnp.ones(n_populations), n_populations) + jnp.diag(jnp.ones(n_populations),
-n_populations),0),switch_t.shape[0],axis=0)
transition_matrix = final_transition
m_curr = jnp.einsum("ij,ikj -> ik", jnp.einsum("ij, kj -> ki", transition_matrix, m_curr),switch_matrix_t)
print("m_curr.shape: ", m_curr.shape)
print("y.shape: ",y.shape)
out = np.sample("y",dist.Normal(m_curr,sigma).to_event(1),obs=y)
print(out.shape)
return (m_curr, t+1),None
#use scan to move along the Markov process so we need to init first
with np.handlers.seed(rng_seed = 23):
m_0_0 = np.param('m_0_0',uniform_init(n_populations),constraint=dist.constraints.simplex)
m_0_1 = np.param('m_0_1',uniform_init(n_populations), constraint=dist.constraints.simplex)
m_0 = jnp.concatenate([m_0_0,m_0_1],axis=0)
m_0 = jnp.expand_dims(m_0,0)
m_0 = jnp.repeat(m_0,num_sequences,axis=0)
scan(transition,(m_0,0),jnp.swapaxes(sequences, 0, 1))