Building a state-space model from which I can sample and interact with inner components

Hi all,
I’ve been working on state-space models (i.e. general state-space HMMs) for a long time, particularly sequential variational inference and sequential Monte Carlo smoothing methods. So far I’ve been using JAX to build my own functional-style Python objects that I can interact with. For example, I’d build something like a parametric transition kernel (a deterministic function applied under some random noise):

```python
class AdditiveNoiseKernel:
    def __init__(self, map, noise_dist):
        self.map = map
        self.noise_dist = noise_dist

    def apply_map(self, x, map_params):
        return self.map(x, **map_params)

    def sample(self, key, x, kernel_params):
        # Deterministic map of the current state plus one noise draw
        noise_sample = self.noise_dist.sample(key, **kernel_params['dist'])
        return self.apply_map(x, kernel_params['map']) + noise_sample

    def logpdf(self, x, kernel_params):
        ...  # body omitted

    def get_random_kernel_params(self, key):
        ...  # body omitted
```

From these building blocks I then build my entire model with something like:

```python
class SSM:
    def __init__(self, state_dim, obs_dim, transition_kernel, emission_kernel, init_dist):
        # store the kernels as member variables, etc.
        ...

    def sample_joint_sequences(self, key, seq_length, num_seqs):
        # stuff involving the transition and emission kernels with lax.scan and jax.vmap
        ...

    def run_bootstrap_smc_filter(self, key, num_samples):
        ...
```
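For reference, here is a minimal plain-JAX sketch of what `sample_joint_sequences` could contain, assuming a linear-Gaussian model with a standard-normal initial distribution (`A`, `C`, `q_scale` and `r_scale` are hypothetical parameter names). `lax.scan` runs a single sequence through time and `jax.vmap` maps it over independent random keys.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sample_joint_sequence(key, seq_length, A, C, q_scale, r_scale, state_dim, obs_dim):
    # Assumed model: x_t = A x_{t-1} + q_scale * eps_t,  y_t = C x_t + r_scale * eta_t
    key_init, key_scan = jax.random.split(key)
    x0 = jax.random.normal(key_init, (state_dim,))  # x_0 ~ N(0, I), not returned

    def step(x_prev, step_key):
        k_trans, k_emit = jax.random.split(step_key)
        x = A @ x_prev + q_scale * jax.random.normal(k_trans, (state_dim,))  # transition
        y = C @ x + r_scale * jax.random.normal(k_emit, (obs_dim,))          # emission
        return x, (x, y)

    # One key per time step; scan returns the stacked states x_1..x_T and observations
    _, (xs, ys) = lax.scan(step, x0, jax.random.split(key_scan, seq_length))
    return xs, ys


def sample_joint_sequences(key, seq_length, num_seqs, **model):
    # Vectorise over independent sequences, one key per sequence
    keys = jax.random.split(key, num_seqs)
    return jax.vmap(lambda k: sample_joint_sequence(k, seq_length, **model))(keys)


xs, ys = sample_joint_sequences(
    jax.random.PRNGKey(0), seq_length=50, num_seqs=8,
    A=0.9 * jnp.eye(2), C=jnp.ones((1, 2)), q_scale=0.1, r_scale=0.5,
    state_dim=2, obs_dim=1,
)
```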

I want to convert this into NumPyro code, but I’m not sure the framework’s philosophy is intended for this, or maybe I’m missing the point. Where I’m lost is that I’ve been handling most of the randomness and building kind of “probabilistic” objects myself, and I feel this should instead be handled by NumPyro primitives and distributions. At the same time, it’s unclear to me how to achieve fairly simple tasks with NumPyro, like sampling toy data from the SSM once the model function has been defined.

Thanks in advance!

Your implementation of AdditiveNoiseKernel looks good. Maybe consider moving kernel_params into the constructor (`__init__`)?
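A minimal sketch of that suggestion, assuming Gaussian noise with a scalar scale (`AdditiveGaussianKernel` is a hypothetical variant, not the original poster’s class): the parameters are bound once in `__init__`, so `sample` and `logpdf` only take states.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class AdditiveGaussianKernel:
    def __init__(self, map_fn, map_params, noise_scale):
        # Parameters are fixed at construction instead of passed to every call
        self.map_fn = map_fn
        self.map_params = map_params
        self.noise_scale = noise_scale

    def mean(self, x):
        return self.map_fn(x, **self.map_params)

    def sample(self, key, x):
        # Deterministic map plus isotropic Gaussian noise
        noise = self.noise_scale * jax.random.normal(key, jnp.shape(x))
        return self.mean(x) + noise

    def logpdf(self, x_next, x):
        # log N(x_next; mean(x), noise_scale^2 I), summed over state dimensions
        return jnp.sum(norm.logpdf(x_next, self.mean(x), self.noise_scale))


kernel = AdditiveGaussianKernel(lambda x, A: A @ x, {"A": 0.5 * jnp.eye(2)}, noise_scale=0.1)
x = jnp.ones(2)
x_next = kernel.sample(jax.random.PRNGKey(0), x)
log_p = kernel.logpdf(x_next, x)
```

One trade-off: keeping the parameters as an explicit argument makes it straightforward to differentiate with respect to them via `jax.grad` (as in variational inference), whereas parameters stored on the object require rebuilding the kernel, or registering it as a pytree, for that.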