Hi. I am working with a longitudinal probabilistic model which characterises some dependencies with a copula. I have created a simplified version below:
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

# copula_lpdfs, kernel_cdfs and rho are defined elsewhere in my code.

def test(T=1):
    Z1 = numpyro.sample('Z1', dist.Normal(0., 1.))
    Y1 = numpyro.sample('Y1', dist.Normal(0., 1.))
    # Map the first-step variables to the unit interval via their marginal CDFs
    q_Z1 = numpyro.deterministic('qZ1', dist.Normal(0., 1.).cdf(Z1))
    q_Y1 = numpyro.deterministic('qY1', dist.Normal(0., 1.).cdf(Y1))
    cop_y1z1 = numpyro.factor(
        'Y1Z1',
        copula_lpdfs.bivariate_gaussian_copula_lpdf(
            {'Y1': q_Y1, 'Z1': q_Z1},
            {'_': rho}
        )
    )

    def transition_fn(state, new):
        Z_prev, t = state
        Zt = numpyro.sample('Zt', dist.Normal(Z_prev, 1))
        # This lookup is the problem: t is traced inside scan, so it cannot build a dictionary key
        q_Zt = numpyro.deterministic('qZt', kernel_cdfs[f"Zt{t}"].cdf(Zt))
        Yt = numpyro.sample('Yt', dist.Normal(0, 1.))
        q_Yt = numpyro.deterministic('qYt', dist.Normal(0., 1.).cdf(Yt))
        cop_ytzt = numpyro.factor(
            'YtZt',
            copula_lpdfs.bivariate_gaussian_copula_lpdf(
                {'Yt': q_Yt, 'Zt': q_Zt},
                {'_': rho}
            )
        )
        return (Zt, t + 1), None

    scan(transition_fn, (Z1, 1), xs=None, length=T)
I have a set of latent states Z whose evolution is characterised by a conditional dependence Z_2 | Z_1. However, the dependencies between Y_t and Z_t are modelled via a copula on the marginal CDF F_{Z_t}. I have calculated empirical solutions for these CDFs using JAX-compatible KDE functions and have stored them in a dictionary kernel_cdfs. I would like to index the appropriate function within the transition function in the scan() section, but am unable to do this in the current format because the time index t is a traced JAX array (a tracer), not a Python integer.
Are there any suggestions on how I may proceed? Can I pull out the underlying value of t for indexing the dictionary, or will I need to change my approach entirely?
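To make the question concrete, here is a minimal sketch of one direction that might work, written in plain JAX rather than NumPyro. It assumes the dictionary can be turned into a list of callables ordered by time step, and it relies on jax.lax.switch, which accepts a traced integer index. The names make_cdf, cdf_fns, cdf_at and step are hypothetical, and simple Normal CDFs stand in for my KDE-based CDFs. I have not checked how this behaves inside NumPyro's scan.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical stand-ins for the KDE-based CDFs stored in kernel_cdfs:
# one callable per time step, all mapping a scalar to a scalar of the same dtype.
def make_cdf(loc, scale):
    return lambda z: norm.cdf(z, loc=loc, scale=scale)

cdf_fns = [make_cdf(0.0, 1.0), make_cdf(0.0, 2.0), make_cdf(1.0, 1.0)]

def cdf_at(t, z):
    # lax.switch selects a branch using a traced integer index,
    # so no Python-level dictionary lookup on t is needed.
    return jax.lax.switch(t, cdf_fns, z)

def step(carry, _):
    z, t = carry
    q = cdf_at(t, z)          # CDF for the current time step
    return (z, t + 1), q      # z is held fixed here purely for illustration

init = (jnp.array(0.0), jnp.array(0))
_, qs = jax.lax.scan(step, init, xs=None, length=3)
```

Every branch passed to lax.switch has to return the same shape and dtype, and the time index is zero-based here. If all the CDFs share one functional form, stacking their parameters into an array and indexing that array with t may be simpler than switching between functions.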
Thanks