I would like to code up the Markov Modulated Poisson Process described in Appendix Section B.1 of “Fast and Flexible Temporal Point Processes with Triangular Maps” (https://arxiv.org/pdf/2006.12631).
I consider the prior to be
\begin{align}
p(\mathbf{t},\mathbf{s}|\mathbf{\pi},\mathbf{A}) p(\mathbf{\pi},\mathbf{A})
\end{align}
where \mathbf{(t,s)} are the jump times and states of the underlying Markov Jump Process (MJP), \mathbf{A} is the matrix of transition rates between states of the MJP, and \mathbf{\pi} is the distribution over the initial state of the MJP.
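For reference, the factorization of the MJP path density I have in mind (this is my understanding, with exit rates q_k = \sum_{j \neq k} A_{kj}, jump times 0 = t_0 < t_1 < \dots < t_n \le T, and a survival term for the final segment) is
\begin{align}
p(\mathbf{t},\mathbf{s}|\mathbf{\pi},\mathbf{A}) = \pi_{s_0} \left[ \prod_{i=1}^{n} A_{s_{i-1}, s_i} \, e^{-q_{s_{i-1}} (t_i - t_{i-1})} \right] e^{-q_{s_n} (T - t_n)}.
\end{align}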
The likelihood is the Poisson process, whose intensity function depends on the state s at time t:
\begin{align}
p(\mathbf{o}|\mathbf{t},\mathbf{s},\mathbf{\lambda})
\end{align}
where \mathbf{o} collects the numbers of event arrivals in each interval [t_{i-1},t_{i}], and \mathbf{\lambda} are the Poisson process rates, one per MJP state. You observe a stream of events over the window [0,T], and all the events fall inside this window.
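Concretely, conditioned on the MJP path, I take the count o_i in each interval to be Poisson distributed with mean equal to the active state's rate times the interval length:
\begin{align}
p(\mathbf{o}|\mathbf{t},\mathbf{s},\mathbf{\lambda}) = \prod_{i} \frac{\left(\lambda_{s_{i-1}} (t_i - t_{i-1})\right)^{o_i}}{o_i!} \, e^{-\lambda_{s_{i-1}} (t_i - t_{i-1})}.
\end{align}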
When I code the joint distribution, the problem I face is that a while loop seems necessary: when writing the generative process for NumPyro (to specify the joint distribution), you have to keep sampling jump times until t > T, and only then stop sampling.
The rough outline of the code I want (but which breaks because of the while loop) is:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(arrival_times, T, num_states):
    # The diagonal elements of the rate matrix are NOT sampled/needed, so they
    # should not be included in the log-likelihood computation.
    mask = jnp.logical_not(jnp.eye(num_states, dtype=bool))
    # Sample only the off-diagonal elements using the mask.
    transition_rates = numpyro.sample(
        "transition_rates",
        dist.Exponential(1.0).expand([num_states, num_states]).mask(mask).to_event(2),
    )
    # Compute the diagonal elements so that the rows of the generator (Q-matrix) sum to zero.
    q_without_diagonal = jnp.tril(transition_rates, k=-1) + jnp.triu(transition_rates, k=1)
    q_only_diagonal = -jnp.diagflat(jnp.sum(q_without_diagonal, axis=-1))
    # Generator matrix.
    q_matrix = q_only_diagonal + q_without_diagonal
    # numpyro.deterministic("generator", q_matrix)

    # Poisson process rates (each state has a unique rate).
    poisson_rates = numpyro.sample(
        "poisson_rates",
        dist.Exponential(1.0).expand([num_states]).to_event(1),
    )
    # Initial state probability.
    initial_prob = numpyro.sample("initial_prob", dist.Dirichlet(jnp.ones(num_states)))
    # Current state (the first state).
    current_state = numpyro.sample("current_state_0", dist.Categorical(initial_prob))

    time_index = 0
    current_time = 0.0
    log_likelihood = 0.0
    # Continue sampling/generating until the Markov jump process produces a jump
    # time outside the observation window.
    while current_time < T:
        # Exit rate and next-state probabilities for the current state.
        exit_rate = -q_matrix[current_state, current_state]
        transition_rates_by_state = q_matrix[current_state, :]
        transition_probs = jnp.where(transition_rates_by_state > 0, transition_rates_by_state, 0.0) / exit_rate + 1e-10
        # Stay in the current state for a given hold time, then transition to the next state.
        hold_time = numpyro.sample(f"waiting_time_{time_index}", dist.Exponential(exit_rate))
        next_state = numpyro.sample(f"current_state_{time_index + 1}", dist.Categorical(transition_probs))
        current_poisson_rate = poisson_rates[current_state]
        # Increment the current time.
        current_time = current_time + hold_time
        time_index = time_index + 1
        # Clip the interval to the observation window, so segments (or parts of
        # segments) past T contribute nothing.
        end_time = jnp.clip(current_time, 0.0, T)
        start_time = jnp.clip(current_time - hold_time, 0.0, T)
        # Number of arrivals/events within the interval [t_{i-1}, t_i].
        N_arrivals = jnp.sum(jnp.where(jnp.logical_and(arrival_times < end_time, arrival_times >= start_time), 1.0, 0.0))
        # Log-probability of the Poisson process on this interval.
        log_term = N_arrivals * jnp.log(current_poisson_rate) - current_poisson_rate * (end_time - start_time)
        log_likelihood = log_likelihood + log_term
        # The current state becomes the next state.
        current_state = next_state
    numpyro.factor("log_likelihood", log_likelihood)
I am using the DiscreteHMCGibbs sampler, as it does not seem to perform enumeration of the discrete random variables.
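For reference, this is roughly how I set the sampler up (the warmup/sample counts and the seed below are just placeholders):

from jax import random
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

# NUTS handles the continuous sites; DiscreteHMCGibbs resamples the discrete sites with Gibbs steps.
kernel = DiscreteHMCGibbs(NUTS(model))
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), arrival_times, T, num_states)
mcmc.print_summary()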
Any help would be incredible.