Indexing external dictionary in dynamic model

Hi. I am working with a longitudinal probabilistic model which characterises some dependencies with a copula. I have created a simplified version below:

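import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def gaussian_copula_logpdf(u, v, rho):
    # Hypothetical stand-in for the copula term: bivariate Gaussian copula log-density
    x = dist.Normal(0., 1.).icdf(u)
    y = dist.Normal(0., 1.).icdf(v)
    return (-0.5 * jnp.log1p(-rho**2)
            - (rho**2 * (x**2 + y**2) - 2 * rho * x * y) / (2 * (1 - rho**2)))


def test(T=1, rho=0.5):  # rho: copula correlation, fixed here for simplicity
    Z1 = numpyro.sample('Z1', dist.Normal(0., 1.))
    Y1 = numpyro.sample('Y1', dist.Normal(0., 1.))
    q_Z1 = numpyro.deterministic('qZ1', dist.Normal(0., 1.).cdf(Z1))
    q_Y1 = numpyro.deterministic('qY1', dist.Normal(0., 1.).cdf(Y1))
    numpyro.factor('cop_y1z1', gaussian_copula_logpdf(q_Y1, q_Z1, rho))

    def transition_fn(state, new):
        Z_prev, t = state
        Zt = numpyro.sample('Zt', dist.Normal(Z_prev, 1.))
        # Fails: inside scan, t is a traced array, so f"Zt{t}" is not "Zt1", "Zt2", ...
        q_Zt = numpyro.deterministic('qZt', kernel_cdfs[f"Zt{t}"].cdf(Zt))

        Yt = numpyro.sample('Yt', dist.Normal(0., 1.))
        q_Yt = numpyro.deterministic('qYt', dist.Normal(0., 1.).cdf(Yt))
        numpyro.factor('cop_ytzt', gaussian_copula_logpdf(q_Yt, q_Zt, rho))
        return (Zt, t + 1), None

    scan(transition_fn, (Z1, 1), xs=None, length=T)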

I have a set of latent states Z whose evolution is characterised by a conditional dependence Z_t | Z_{t-1} (e.g. Z_2 | Z_1). The dependence between Y_t and Z_t, however, is modeled via a copula on the marginal CDF F_{Z_t}. I have calculated empirical estimates of these CDFs using JAX-compatible KDE functions and stored them in a dictionary, kernel_cdfs. I would like to index the appropriate function inside the transition function of scan(), but I can't in the current form because the time index t is a traced array, not a Python integer.
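A minimal sketch of the failure, assuming plain jax.lax.scan (NumPyro's scan runs the body through lax.scan in a similar way): inside the loop body t is a tracer, so the string key built from it never matches a dictionary entry. The values in kernel_cdfs here are hypothetical placeholders.

```python
import jax

# Hypothetical placeholders for precomputed per-step KDE objects
kernel_cdfs = {"Zt1": "kde for t=1", "Zt2": "kde for t=2"}
seen_keys = []

def body(t, _):
    key = f"Zt{t}"         # t is a tracer here, not a Python int
    seen_keys.append(key)  # record the key actually built while tracing
    return t + 1, None

jax.lax.scan(body, 1, xs=None, length=2)
print(seen_keys)  # contains the tracer's repr, not "Zt1"
print(any(k in kernel_cdfs for k in seen_keys))  # False: kernel_cdfs[key] would raise KeyError
```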

Are there any suggestions on how I may proceed? Can I pull out the underlying value of t for indexing the dictionary or will I need to change my approach entirely?


You just need to change your kernel dictionary into a function that depends on the array value of t (rather than on a string key built from t).
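One way to read this suggestion, as a minimal sketch: precompute the per-step quantities once as arrays with a leading time axis, then index them with t. The samples array, the Scott's-rule bandwidth and kde_cdf are hypothetical stand-ins for the poster's KDE objects, not their actual API. jax.jit stands in for scan here; in both cases t arrives as a traced integer.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical precomputed data: N marginal samples of Z_t for each of T steps
T, N = 3, 500
samples = jax.random.normal(jax.random.PRNGKey(0), (T, N)) + jnp.arange(T)[:, None]
# Scott's rule bandwidth per time step, computed once up front
bandwidths = samples.std(axis=1) * N ** (-1 / 5)

def kde_cdf(t, z):
    """Gaussian-KDE CDF at z for time step t; t may be a traced integer."""
    x, h = samples[t], bandwidths[t]  # array indexing works with a traced t
    return norm.cdf((z - x) / h).mean()

kde_cdf_jit = jax.jit(kde_cdf)  # t is traced here, just as inside scan
print(kde_cdf_jit(0, 0.0))  # close to 0.5, since step 0's samples are centred at 0
```

Nothing is rebuilt per call: the arrays are closed over once, and each call only slices them.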

Thanks @fehiepsi

I’m not sure if I followed you completely, but I tried the following. I created a function to calculate the kernel object from an input set of samples:

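def construct_kernel(var, t, samples):
    # Build a KDE from the marginal samples of `var` at time step t
    return GaussianKDEJax(samples[var][:, t])

# Inside transition_fn: rebuilds the KDE on every call
q_Zt = numpyro.deterministic('qZt', construct_kernel('Zt', t, marg_samples).cdf(Zt))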

This recalculates the kernel function at every step. However, it is slow: sampling takes around 55% longer than when we simply use precalculated kernel values.

I’m certain I missed something in your suggestion, but would you be able to elaborate further on how I could construct this function such that I don’t need to recalculate the kernel density function at each iteration?

Never mind… managed to sort it by using:

q_Zt = numpyro.deterministic('qZt', jax.lax.switch(t, kde_z1s, Zt))    

to call the kernel term at index t within kde_z1s (lax.switch takes a sequence of callables). lax.switch came in clutch here.
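A self-contained sketch of the lax.switch pattern, assuming each entry of kde_z1s is a callable that maps z to a CDF value. make_kde_cdf is a hypothetical stand-in for GaussianKDEJax(...).cdf, and plain jax.lax.scan replaces NumPyro's scan to keep the example dependency-free.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def make_kde_cdf(samples):
    """Hypothetical stand-in: Gaussian-KDE CDF closed over fixed samples."""
    h = samples.std() * samples.shape[0] ** (-1 / 5)  # Scott's rule bandwidth
    return lambda z: norm.cdf((z - samples) / h).mean()

key = jax.random.PRNGKey(1)
# One branch per time step; step t's samples are centred at t
kde_z1s = [make_kde_cdf(jax.random.normal(k, (500,)) + t)
           for t, k in enumerate(jax.random.split(key, 3))]

def step(t, z):
    q = jax.lax.switch(t, kde_z1s, z)  # t is traced; switch runs branch t
    return t + 1, q

_, qs = jax.lax.scan(step, 0, jnp.arange(3.0))
print(qs)  # each entry is close to 0.5
```

Two things to keep in mind: lax.switch clamps an out-of-range index to the nearest valid branch, so it is worth checking that the starting value of t (1 in the original model) lines up with the order of kde_z1s; and every branch is traced, so the compiled program grows with the number of time steps.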


Assuming you have a pytree of precomputed kernel arrays, you can select the element at time step t using

kernel = jax.tree_util.tree_map(lambda x: x[t], kernels)
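Here is a sketch of that pytree approach inside lax.scan, assuming a hypothetical layout where every leaf of kernels has a leading time axis of length T; kde_cdf is again a stand-in Gaussian-KDE CDF rather than any particular library's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

T, N = 4, 500
samples = jax.random.normal(jax.random.PRNGKey(2), (T, N)) + jnp.arange(T)[:, None]
# Hypothetical pytree layout: every leaf has a leading time axis of length T
kernels = {
    "samples": samples,
    "bandwidth": samples.std(axis=1) * N ** (-1 / 5),  # Scott's rule per step
}

def kde_cdf(kernel, z):
    """Gaussian-KDE CDF for one time step's slice of the pytree."""
    return norm.cdf((z - kernel["samples"]) / kernel["bandwidth"]).mean()

def step(t, z):
    kernel = jax.tree_util.tree_map(lambda x: x[t], kernels)  # slice every leaf at t
    return t + 1, kde_cdf(kernel, z)

_, qs = jax.lax.scan(step, 0, jnp.arange(T, dtype=jnp.float32))
print(qs)  # each entry is close to 0.5
```

Compared with lax.switch, this traces a single code path and just slices the arrays with t, which may scale better when there are many time steps.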