Hi all,
I’ve been working on state-space models (i.e. general state-space HMMs) for a long time, particularly sequential variational inference and sequential Monte Carlo smoothing methods. So far I’ve been using JAX, building my own functional-style Python objects that I can interact with. For example, I’d build something like a parametric transition kernel (a deterministic function whose output is perturbed by random noise):
```python
class AdditiveNoiseKernel:
    def __init__(self, map, noise_dist):
        self.map = map                # deterministic function: x -> map(x, **map_params)
        self.noise_dist = noise_dist  # noise object exposing sample(key, **dist_params)

    def apply_map(self, x, map_params):
        return self.map(x, **map_params)

    def sample(self, key, x, kernel_params):
        noise_sample = self.noise_dist.sample(key, **kernel_params['dist'])
        # apply_map expects the parameter dict itself, not unpacked keywords
        return self.apply_map(x, kernel_params['map']) + noise_sample

    def logpdf(self, y, x, kernel_params):
        # log p(y | x)
        ...

    def get_random_kernel_params(self, key):
        ...
```
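For concreteness, here is a minimal runnable version of that kernel, assuming isotropic Gaussian noise with a scalar standard deviation and a user-supplied map. The class name `GaussianAdditiveNoiseKernel`, the `linear_map` helper, and the `{"map": ..., "dist": {"scale": ...}}` parameter layout are hypothetical, chosen only to mirror the class above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class GaussianAdditiveNoiseKernel:
    """Hypothetical kernel: y = map(x, **map_params) + scale * eps, eps ~ N(0, I)."""

    def __init__(self, map_fn):
        self.map_fn = map_fn

    def apply_map(self, x, map_params):
        return self.map_fn(x, **map_params)

    def sample(self, key, x, kernel_params):
        mean = self.apply_map(x, kernel_params["map"])
        eps = jax.random.normal(key, mean.shape)
        return mean + kernel_params["dist"]["scale"] * eps

    def logpdf(self, y, x, kernel_params):
        mean = self.apply_map(x, kernel_params["map"])
        # Independent Gaussian noise per dimension: sum the log-densities.
        return jnp.sum(norm.logpdf(y, loc=mean, scale=kernel_params["dist"]["scale"]))


def linear_map(x, A):
    # Hypothetical deterministic part of the transition.
    return A @ x


kernel = GaussianAdditiveNoiseKernel(linear_map)
params = {"map": {"A": 0.9 * jnp.eye(2)}, "dist": {"scale": 0.1}}
x0 = jnp.ones(2)
x1 = kernel.sample(jax.random.PRNGKey(0), x0, params)
lp = kernel.logpdf(x1, x0, params)
```

Because the parameters are plain dicts of arrays (pytrees), the same methods can be passed through `jax.vmap` or `jax.grad` without changes.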
From these building blocks I then build entire models with something like:
```python
class SSM:
    def __init__(self, state_dim, obs_dim, transition_kernel, emission_kernel, init_dist):
        # store the kernels and dimensions as member variables, etc.
        ...

    def sample_joint_sequences(self, key, seq_length, num_seqs):
        # stuff involving the transition and emission kernels with lax.scan and jax.vmap
        ...

    def run_bootstrap_smc_filter(self, key, num_samples):
        ...
```
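A sketch of what `sample_joint_sequences` typically looks like in plain JAX, assuming a linear-Gaussian SSM so the example is self-contained. The function names and parameters (`A`, `C`, `q_scale`, `r_scale`, `m0`, `p0_scale`) are hypothetical stand-ins for the generic kernels above: `lax.scan` runs one trajectory through time, and `jax.vmap` batches independent trajectories over split keys.

```python
import jax
import jax.numpy as jnp


def sample_joint_sequence(key, seq_length, A, C, q_scale, r_scale, m0, p0_scale):
    """One (states, observations) trajectory of a hypothetical linear-Gaussian SSM:
    x_0 ~ N(m0, p0_scale^2 I), x_t = A x_{t-1} + q_scale * eps_t, y_t = C x_t + r_scale * nu_t.
    """
    key_init, key_scan = jax.random.split(key)
    x0 = m0 + p0_scale * jax.random.normal(key_init, m0.shape)

    def step(x, step_key):
        k_emit, k_trans = jax.random.split(step_key)
        y = C @ x + r_scale * jax.random.normal(k_emit, (C.shape[0],))
        x_next = A @ x + q_scale * jax.random.normal(k_trans, x.shape)
        # Carry the next state forward; record the current (x_t, y_t) pair.
        return x_next, (x, y)

    _, (xs, ys) = jax.lax.scan(step, x0, jax.random.split(key_scan, seq_length))
    return xs, ys


def sample_joint_sequences(key, seq_length, num_seqs, **params):
    # One independent key per sequence, then vectorise over the keys.
    keys = jax.random.split(key, num_seqs)
    return jax.vmap(lambda k: sample_joint_sequence(k, seq_length, **params))(keys)


A = 0.9 * jnp.eye(2)
C = jnp.array([[1.0, 0.0]])
xs, ys = sample_joint_sequences(
    jax.random.PRNGKey(0), 50, 8,
    A=A, C=C, q_scale=0.1, r_scale=0.5, m0=jnp.zeros(2), p0_scale=1.0,
)
# xs has shape (8, 50, 2) and ys has shape (8, 50, 1)
```

`seq_length` is a Python int closed over by the lambda; if the function is wrapped in `jax.jit`, it would need to be marked static.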
I want to convert this into NumPyro code, but I’m not sure the framework’s philosophy is intended for this, or maybe I’m missing the point. Where I get lost is that I’ve been handling most of the randomness and building these “probabilistic” objects myself. I feel like this should instead be handled by NumPyro primitives and distributions, but at the same time it’s unclear to me how to achieve fairly simple tasks with NumPyro, like sampling toy data from the SSM once the model function has been defined.
Thanks in advance!