Have you tried `poutine.mask` (`numpyro.handlers.mask` in NumPyro)? You can generate a mask for one part and then invert it for the other, and you can build the mask from the realisation of T.
You are right, but when you do inference you provide a random generator key from JAX, and that is what triggers the sampling. Here is a small example:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import SVI, Trace_ELBO
rng_key = random.PRNGKey(0)
def model(vars_dict, goals, validate_args):
    # T is sampled here; the print shows its value each time the model runs
    T = numpyro.sample("T", dist.Uniform(0, len(goals)))
    print(T)

def guide(vars_dict, goals, validate_args):
    pass
data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 1, {}, data, None)
This will print the value of T, which means you can use it when building your mask.
Furthermore, do not use numpy; focus on jax.numpy. Check the sharp bits here.
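One of those sharp bits, as a quick illustration: JAX arrays are immutable, so in-place updates must go through the functional `.at` interface instead of numpy-style assignment.

```python
import numpy as np
import jax.numpy as jnp

x = np.zeros(3)
x[0] = 1.0            # fine with plain numpy

y = jnp.zeros(3)
# y[0] = 1.0          # raises TypeError: JAX arrays are immutable
y = y.at[0].set(1.0)  # functional update returns a new array
print(y)  # [1. 0. 0.]
```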
In my example above I use SVI, but the same is valid for MCMC.
Beware that T is a float, so it must be converted to an integer before you use it for indexing.
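A small sketch of that conversion (the value 3.7 is just an illustrative stand-in for a sampled T):

```python
import jax.numpy as jnp

T = jnp.asarray(3.7)                    # a sampled T comes back as a float
idx = jnp.arange(10)
mask = idx < jnp.floor(T)               # boolean masks work directly and are jit-friendly
t_int = jnp.floor(T).astype(jnp.int32)  # explicit integer conversion for indexing
print(int(t_int))       # 3
print(int(mask.sum()))  # 3
```

Inside a jitted model, prefer the boolean-mask form; slicing with a traced integer will fail under jit.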