MCMC for Markov Modulated Poisson Processes

I would like to implement the Markov Modulated Poisson Process described in Appendix B.1 of “Fast and Flexible Temporal Point Processes with Triangular Maps” (https://arxiv.org/pdf/2006.12631).

I consider the prior to be
\begin{align} p(\mathbf{t},\mathbf{s}|\mathbf{\pi},\mathbf{A}) p(\mathbf{\pi},\mathbf{A}) \end{align}
where (\mathbf{t},\mathbf{s}) are the jump times and states of the underlying Markov Jump Process (MJP), \mathbf{A} is the matrix of transition rates between MJP states, and \mathbf{\pi} is the distribution of the initial MJP state.
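For reference, the prior term has a closed form. Writing q_i = \sum_{j \neq i} A_{ij} for the exit rate of state i, with t_0 = 0 and t_{K+1} = T for a path with K jumps (state s_k held on [t_k, t_{k+1})):
\begin{align} \log p(\mathbf{t},\mathbf{s}|\mathbf{\pi},\mathbf{A}) = \log \pi_{s_0} + \sum_{k=1}^{K} \log A_{s_{k-1} s_k} - \sum_{k=0}^{K} q_{s_k}(t_{k+1} - t_k) \end{align}
This holds because leaving state i after a time \tau for state j has density A_{ij} e^{-q_i \tau}, and the last segment is censored at T, so it only contributes its survival term. A minimal JAX sketch, assuming the path is stored as the K jump times plus the K+1 visited states; `mjp_log_prob` is a hypothetical helper name, not something from the paper or NumPyro:

```python
import jax.numpy as jnp


def mjp_log_prob(jump_times, states, T, initial_prob, q_matrix):
    """Log density of an MJP path observed on [0, T] (hypothetical helper).

    jump_times:   shape (K,), sorted jump times t_1 < ... < t_K < T
    states:       shape (K + 1,), visited states s_0, ..., s_K
    initial_prob: shape (num_states,), the initial-state distribution pi
    q_matrix:     generator, off-diagonal rates A_ij and rows summing to 0
    """
    exit_rates = -jnp.diag(q_matrix)  # q_i = -Q_ii = sum_{j != i} A_ij
    # segment boundaries t_0 = 0, t_1, ..., t_K, t_{K+1} = T
    bounds = jnp.concatenate([jnp.zeros(1), jump_times, jnp.array([T])])
    durations = jnp.diff(bounds)  # time spent in each state s_k
    # initial state term: log pi_{s_0}
    log_p = jnp.log(initial_prob[states[0]])
    # each jump s_{k-1} -> s_k contributes log A_{s_{k-1} s_k} (empty sum if K = 0)
    log_p += jnp.sum(jnp.log(q_matrix[states[:-1], states[1:]]))
    # survival exp(-q_{s_k} * duration) for every segment, including the censored last one
    log_p -= jnp.sum(exit_rates[states] * durations)
    return log_p
```

For a two-state chain with one jump at t = 0.5 and T = 1, this evaluates \log 0.5 + \log A_{01} - (q_0 \cdot 0.5 + q_1 \cdot 0.5).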

The likelihood is a Poisson process whose intensity depends on the MJP state s at time t:
\begin{align} p(\mathbf{o|t,s,\lambda}) \end{align}
where o_i is the number of event arrivals in the interval [t_{i-1},t_{i}), and \mathbf{\lambda} are the Poisson rates, one for each MJP state. The events are observed over the window [0,T], and all of them fall inside this window.
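Treating the event times themselves as observed, this is the usual inhomogeneous Poisson process density, which for a piecewise-constant intensity reduces to \sum_{i=1}^{K+1} [o_i \log \lambda_{s_{i-1}} - \lambda_{s_{i-1}}(t_i - t_{i-1})] with t_0 = 0 and t_{K+1} = T. This is the same per-interval `log_term` accumulated in the loop, computed here without a loop. A vectorized sketch, assuming the path is stored as K jump times and K+1 states; `mmpp_log_likelihood` is a hypothetical name:

```python
import jax.numpy as jnp


def mmpp_log_likelihood(arrival_times, jump_times, states, T, poisson_rates):
    """Poisson log-likelihood of event times given an MJP path on [0, T] (hypothetical helper)."""
    # segment index k of each arrival, such that t_k <= arrival < t_{k+1}
    segment = jnp.searchsorted(jump_times, arrival_times, side="right")
    # sum over events of log lambda_{s(t_j)}
    log_intensity = jnp.sum(jnp.log(poisson_rates[states[segment]]))
    # compensator: integral of the piecewise-constant intensity over [0, T]
    bounds = jnp.concatenate([jnp.zeros(1), jump_times, jnp.array([T])])
    integral = jnp.sum(poisson_rates[states] * jnp.diff(bounds))
    return log_intensity - integral
```

`side="right"` assigns an arrival that lands exactly on a jump time to the new state, matching the half-open intervals [t_{i-1}, t_i) used in the loop.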

When I code the joint distribution, the problem is that a while loop seems necessary: when writing the generative process in NumPyro (to specify the joint distribution), you have to keep sampling jump times until t > T, and only then stop.
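The loop failure can be reproduced without NumPyro, assuming (as usual) that the sampler evaluates the model under JAX transformations such as `jax.jit` and `jax.grad`. A Python `while` whose condition depends on a traced value cannot be evaluated, whereas `jax.lax.while_loop` expresses the same deterministic loop in a traceable way. This sketch only illustrates the tracing issue; it does not address the fact that the number of `numpyro.sample` sites would still vary from run to run.

```python
import jax
import jax.numpy as jnp


def python_while(x):
    # a Python `while` needs a concrete bool, but under jit `x` is an abstract tracer
    while x < 1.0:
        x = x + 0.3
    return x


def lax_while(x):
    # traceable equivalent: condition and body are functions of the loop carry
    return jax.lax.while_loop(lambda v: v < 1.0, lambda v: v + 0.3, x)


try:
    jax.jit(python_while)(jnp.float32(0.0))
    python_while_failed = False
except jax.errors.ConcretizationTypeError:
    python_while_failed = True

result = jax.jit(lax_while)(jnp.float32(0.0))  # approximately 1.2
```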

A rough outline of the code I want (which breaks because of the while loop) is:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(arrival_times, T, num_states):
    # Only the off-diagonal rates are free parameters; the mask drops the
    # diagonal entries from the log density
    mask = jnp.logical_not(jnp.eye(num_states, dtype=bool))

    # sample the off-diagonal transition rates (the diagonal is masked out)
    transition_rates = numpyro.sample(
        "transition_rates",
        dist.Exponential(1.0).expand([num_states, num_states]).mask(mask).to_event(2)
    )

    # zero the diagonal, then set it so that each row sums to 0 (generator constraint)
    q_without_diagonal = jnp.tril(transition_rates, k=-1) + jnp.triu(transition_rates, k=1)
    q_only_diagonal = -jnp.diagflat(jnp.sum(q_without_diagonal, axis=-1))

    # generator matrix
    q_matrix = q_only_diagonal + q_without_diagonal
    # numpyro.deterministic("generator", q_matrix)

    # Poisson process rates (each state has its own rate)
    poisson_rates = numpyro.sample(
        "poisson_rates",
        dist.Exponential(1.0).expand([num_states]).to_event(1)
    )

    # initial state probabilities
    initial_prob = numpyro.sample("initial_prob", dist.Dirichlet(jnp.ones(num_states)))

    # current state (the first state)
    current_state = numpyro.sample("current_state_0", dist.Categorical(initial_prob))

    time_index = 0
    current_time = 0.0
    log_likelihood = 0.0
    # keep generating until the Markov jump process leaves the observation window.
    # NOTE: this Python loop is what breaks under MCMC, because current_time is traced
    while current_time < T:
        # jump-chain probabilities out of the *current* state
        exit_rate = -q_matrix[current_state, current_state]
        transition_rates_by_state = q_matrix[current_state, :]
        transition_probs = jnp.where(transition_rates_by_state > 0, transition_rates_by_state, 0) / exit_rate + 1e-10

        next_state = numpyro.sample(f"current_state_{time_index + 1}", dist.Categorical(transition_probs))

        # stay in the current state for an Exponential(exit_rate) holding time
        hold_time = numpyro.sample(f"waiting_time_{time_index}", dist.Exponential(exit_rate))
        current_poisson_rate = poisson_rates[current_state]

        # advance the clock
        current_time = current_time + hold_time
        time_index = time_index + 1

        # clip the interval to the observation window [0, T]
        end_time = jnp.clip(current_time, 0, T)
        start_time = jnp.clip(current_time - hold_time, 0, T)

        # number of arrivals in [start_time, end_time)
        N_arrivals = jnp.sum(jnp.where(jnp.logical_and(arrival_times < end_time, arrival_times >= start_time), 1.0, 0.0))

        # Poisson process log-likelihood for this interval
        log_term = N_arrivals * jnp.log(current_poisson_rate) - current_poisson_rate * (end_time - start_time)
        log_likelihood = log_likelihood + log_term

        # move to the next state
        current_state = next_state

    numpyro.factor("log_likelihood", log_likelihood)
```
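A common workaround for data-dependent loops in JAX is to fix an upper bound on the number of iterations and mask the unused ones. The sketch below applies that idea to forward-simulating the MJP with `jax.lax.scan`. It is not the method from the paper: `simulate_mjp_padded` and `max_jumps` are hypothetical, the bound must be large enough for the rates involved, all exit rates are assumed positive, and it only shows the fixed-shape representation of a path, not a complete inference scheme.

```python
import jax
import jax.numpy as jnp


def simulate_mjp_padded(key, q_matrix, initial_prob, T, max_jumps):
    """Forward-simulate an MJP into fixed-size arrays (hypothetical helper).

    Returns jump_times (max_jumps,), states (max_jumps + 1,) and a boolean
    mask that is True for jumps occurring before T.
    """
    exit_rates = -jnp.diag(q_matrix)
    # jump-chain probabilities P_ij = A_ij / q_i, zero on the diagonal
    off_diag = q_matrix * (1.0 - jnp.eye(q_matrix.shape[0]))
    jump_probs = off_diag / exit_rates[:, None]

    key, init_key = jax.random.split(key)
    s0 = jax.random.categorical(init_key, jnp.log(initial_prob))

    def step(carry, step_key):
        t, s = carry
        hold_key, next_key = jax.random.split(step_key)
        # holding time ~ Exponential(q_s), then jump according to row s of P
        t_next = t + jax.random.exponential(hold_key) / exit_rates[s]
        s_next = jax.random.categorical(next_key, jnp.log(jump_probs[s]))
        return (t_next, s_next), (t_next, s_next)

    keys = jax.random.split(key, max_jumps)
    _, (jump_times, next_states) = jax.lax.scan(step, (jnp.float32(0.0), s0), keys)
    states = jnp.concatenate([s0[None], next_states])
    mask = jump_times < T  # jumps at or after T fall outside the window
    return jump_times, states, mask
```

If `mask[-1]` is `True`, the bound was too small. Because every array has a fixed shape, per-jump log-density terms can be zeroed with `jnp.where(mask, term, 0.0)` instead of being skipped by a loop.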

I am using the DiscreteHMCGibbs sampler as it does not seem to perform enumeration of the discrete random variables.

Any help would be incredible.