Using discrete latent variables inside scan

Hi, I’m currently attempting to use numpyro to do some processing on the output of another tool which makes probability based predictions so it is treated like a (deterministic) Markov process with perturbations that I try to extract some meaning from. The issue is that whether I use SVI or DiscreteHMCGibbs for instance, it doesn’t pick up the switch_t (which is the latent variable of interest). It’s probably something simple, code below. I suspect the discrete latent variable is just marginalised out but then SVI doesn’t seem to actively learn any of the parameters either. Thanks in advance for any pointers.

def FLARE_HMM(sequences, n_populations,device):
    num_sequences, max_length, data_dim = sequences.shape

    key = random.PRNGKey(0)

    #parameters for the beta distribution for the number of switches

    concentration0 = np.param("conc0", init_value=0.002,constraint=dist.constraints.positive)

    concentration1 = np.param("conc1", init_value=1,constraint=dist.constraints.positive)

    sigma = np.param("sigma", random.beta(key=key,a=1,b=10))


    with np.handlers.seed(rng_seed=23):
        haplotype_transition = jnp.transpose(np.param("transition",uniform_init(n_populations,n_populations),constraint=dist.constraints.simplex),(1,0))
        #turn transition into a double transition matrix
        haplotype_transition1 = jnp.concatenate([haplotype_transition,jnp.zeros((n_populations,n_populations))],1)
        haplotype_transition2 = jnp.concatenate([jnp.zeros((n_populations, n_populations)),haplotype_transition], 1)
        final_transition = jnp.concatenate([haplotype_transition1,haplotype_transition2],0)

        sw_fact = np.param("switch_factor1", lambda rng_key: dist.Beta(concentration0,concentration1).sample(rng_key,(1,)), rng_key=key)


    def transition(carry, y):
        key = random.PRNGKey(0)
        m_curr, t = carry

        with np.plate("sequences1", num_sequences,dim=-3):


            switch_t = np.sample("switch",dist.Bernoulli(probs=sw_fact), rng_key=key, infer= {"enumerate": "parallel"})

        with np.plate("sequences2", num_sequences, dim=-2):

            
            switch_matrix_t = switch_t * jnp.repeat(jnp.expand_dims(jnp.eye(2 * n_populations),0),switch_t.shape[0],axis=0) + (1 - switch_t) * jnp.repeat(jnp.expand_dims(
                        jnp.diag(jnp.ones(n_populations), n_populations) + jnp.diag(jnp.ones(n_populations),
                                                                                    -n_populations),0),switch_t.shape[0],axis=0)
            transition_matrix = final_transition
            m_curr = jnp.einsum("ij,ikj -> ik", jnp.einsum("ij, kj -> ki", transition_matrix, m_curr),switch_matrix_t)
            print("m_curr.shape: ", m_curr.shape)
            print("y.shape: ",y.shape)

            out = np.sample("y",dist.Normal(m_curr,sigma).to_event(1),obs=y)
            print(out.shape)

        return (m_curr, t+1),None

    #use scan to move along the Markov process so we need to init first
    with np.handlers.seed(rng_seed = 23):
        m_0_0 = np.param('m_0_0',uniform_init(n_populations),constraint=dist.constraints.simplex)
        m_0_1 = np.param('m_0_1',uniform_init(n_populations), constraint=dist.constraints.simplex)
    m_0 = jnp.concatenate([m_0_0,m_0_1],axis=0)
    m_0 = jnp.expand_dims(m_0,0)
    m_0 = jnp.repeat(m_0,num_sequences,axis=0)
    scan(transition,(m_0,0),jnp.swapaxes(sequences, 0, 1))