Hi folks,
I’m trying to convert a Stan prototype (stan code here) to NumPyro. The prototype runs well in Stan, but I want to test it with JAX and NumPyro to see whether I can improve performance further and make it scale. So far I see two main challenges:
1. the for loop in Stan
2. the user-specified likelihood
It seems to me that 1. is doable with the `scan` function. However, when it is combined with 2., I’m a bit lost. Following some previous posts, it looks like I should create my own sampler, but I also wonder whether I should create my own distribution instead. If so, is there a tutorial that shows the minimal work needed to subclass the `Distribution` class to make this work?
Thanks for the attention!
Edwin
You can use `numpyro.factor` for a custom likelihood. It should work with `scan`.
Thanks for the suggestion. It looks like I can follow this pattern:
```python
def model():
    # ...
    if numpyro.get_mask() is not False:
        log_density = my_expensive_computation()
        numpyro.factor("foo", log_density)
    # ...
```
where `my_expensive_computation` can take some latent variables from other `sample` statements and output an additional `log_density`? Am I right?
So I came up with this:
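```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def local_level(y, a1, p1):
    obs_sigma = numpyro.sample(
        "obs_sigma", dist.TruncatedNormal(low=1e-10, loc=1.0, scale=1.0)
    )
    state_sigma = numpyro.sample(
        "state_sigma", dist.TruncatedNormal(low=1e-10, loc=1.0, scale=1.0)
    )
    a = a1
    p = p1
    log_prob = jnp.zeros(1)
    n_steps = len(y)  # number of time steps
    for t in range(n_steps):
        vt = y[t] - a
        Ft = p + jnp.square(obs_sigma)
        a = a + p * vt / Ft
        # P_{t+1} = P_t * (1 - K_t) + sigma_eta^2, with gain K_t = P_t / F_t
        p = p * (1 - p / Ft) + jnp.square(state_sigma)
        # log-likelihood term: -0.5 * (log|F_t| + v_t^2 / F_t)
        log_prob += -0.5 * (jnp.log(jnp.fabs(Ft)) + jnp.square(vt) / Ft)
    numpyro.factor("kalman_filter_ll", log_prob)


# y: 1-D array of observations; num_chains: number of chains (both defined elsewhere)
kernel = NUTS(local_level)
mcmc = MCMC(kernel, num_warmup=3, num_samples=1, num_chains=num_chains, chain_method='parallel')
a1 = jnp.zeros(1)
p1 = jnp.ones(1)
mcmc.run(random.PRNGKey(0), y, a1=a1, p1=p1)
```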
The run gets stuck somewhere after the first (validation?) pass. Am I following this correctly?
Could you use `scan` instead of the for loop? Could you also print the shapes of some variables in the loop to see whether they match your expectations?
Thanks for the prompt response @fehiepsi. Replacing the for loop with scan solved the problem. I thought a for loop would be slow but still an acceptable way to start a prototype; I was wrong. With scan, the model runs super fast! I guess the lesson is to always use scan? And your suggestion to test my transition function with scan first is a great idea! Thank you!
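A likely explanation for why the loop version appeared stuck: NUTS jit-compiles the model's potential energy and its gradient, and under jit a Python for loop is unrolled, so the traced program (and with it the compile time) grows with the number of time steps, whereas `lax.scan` traces the step function once. The toy sketch below uses hypothetical `loop_filter`/`scan_filter` functions, unrelated to the model, and counts the equations in each jaxpr to make this visible.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loop_filter(x):
    # Python loop: traced once per iteration, so the program grows with len(x).
    c = jnp.zeros(())
    for t in range(x.shape[0]):
        c = 0.9 * c + x[t]
    return c


def scan_filter(x):
    # lax.scan: the step body is traced once, regardless of len(x).
    def step(c, x_t):
        return 0.9 * c + x_t, None

    c, _ = lax.scan(step, jnp.zeros(()), x)
    return c


x = jnp.ones(200)
n_loop = len(jax.make_jaxpr(loop_filter)(x).jaxpr.eqns)
n_scan = len(jax.make_jaxpr(scan_filter)(x).jaxpr.eqns)
print(n_loop, n_scan)
```

A Python loop is still fine for short, fixed-length loops; for long time series, `scan` keeps the traced program the same size regardless of the series length.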