How to use a PyTorch embedding model with NumPyro?

Assume I have 100 sequences, e.g. [["A","C","R"],["D","E","F"]…]. Each group consists of 10 sequences, and each group is converted into an embedding by a PyTorch model. I want to use DiscreteHMCGibbs in NumPyro to sample groups that are as diverse as possible.
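A minimal sketch of the data layout, assuming every sequence is a list of single-letter codes of the same length. The alphabet `ALPHABET` and the helper `encode` are hypothetical, not from the post. Encoding letters as integers gives an array that JAX can index, and a group is then just a set of 10 row indices into it:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical alphabet: the 20 standard amino-acid letters.
ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
CHAR_TO_ID = {c: i for i, c in enumerate(ALPHABET)}

def encode(sequences):
    """Map equal-length letter lists to an int32 array of shape (n_seqs, seq_len)."""
    ids = [[CHAR_TO_ID[c] for c in seq] for seq in sequences]
    return jnp.asarray(np.array(ids, dtype=np.int32))

# Toy data: 100 random sequences of length 3, for illustration only.
rng = np.random.default_rng(0)
sequences = [list(rng.choice(list(ALPHABET), size=3)) for _ in range(100)]
encoded = encode(sequences)        # shape (100, 3)

# A group is a set of 10 row indices into the encoded array.
group_indices = jnp.arange(10)
group = encoded[group_indices]     # shape (10, 3)
```

Because `encoded[group_indices]` is plain JAX indexing, it also works when `group_indices` is a traced value, which is not true of arbitrary Python or PyTorch code.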
My pseudocode looks like this:

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

# torch_model, n_seqs, group_size and init_projected_group are defined elsewhere.

def torch_embedding_function(torch_model, sequence_group):
    return torch_model(sequence_group)  # shape is (1, 128)

def cosine_similarity(x1, x2):
    # Normalize each embedding along the feature axis, then take the dot product.
    norm_x1 = x1 / jnp.linalg.norm(x1, axis=-1, keepdims=True)
    norm_x2 = x2 / jnp.linalg.norm(x2, axis=-1, keepdims=True)
    # Shape (1,) for (1, 128) inputs; no .item(), which needs a concrete value.
    return jnp.sum(norm_x1 * norm_x2, axis=-1)

def potential_function(projected_group1, projected_group2):
    similarity = cosine_similarity(projected_group1, projected_group2)
    return similarity  # Higher similarity = larger penalty

def propose_indices(n_seqs, group_size):
    # Uniform prior over the n_seqs sequences; group_size draws (with replacement).
    probs = jnp.ones(n_seqs) / n_seqs
    sampled_indices = numpyro.sample('sampled_indices', dist.Categorical(probs=probs), sample_shape=(group_size,))
    return sampled_indices

def model(n_seqs, group_size, init_projected_group):
    sampled_indices = propose_indices(n_seqs, group_size)

    projected_group = torch_embedding_function(torch_model, sampled_indices)

    penalty = potential_function(projected_group, init_projected_group)

    numpyro.factor('penalty', -penalty)  # Negative sign: dissimilar groups get a higher log density

    return sampled_indices


nuts_kernel = NUTS(model)
gibbs_kernel = DiscreteHMCGibbs(nuts_kernel)
mcmc = MCMC(gibbs_kernel, num_warmup=500, num_samples=1000)
mcmc.run(
    random.PRNGKey(0),
    n_seqs,
    group_size,
    init_projected_group,
)

samples = mcmc.get_samples()
print("Sampled indices:", samples['sampled_indices'])
```
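`numpyro.factor(name, log_factor)` adds `log_factor` to the model's log density, so states with a larger `log_factor` are visited more often. The sketch below (the helper `diversity_log_weight` is hypothetical) computes the cosine similarity with `jax.numpy` only, without `.item()`, which needs a concrete value and fails on traced arrays, and uses the negative similarity as the log weight:

```python
import jax
import jax.numpy as jnp

def cosine_similarity(x1, x2):
    # Normalize along the embedding axis, then take the dot product.
    # Returns an array (no .item()), so it also works on traced values.
    x1 = x1 / jnp.linalg.norm(x1, axis=-1, keepdims=True)
    x2 = x2 / jnp.linalg.norm(x2, axis=-1, keepdims=True)
    return jnp.sum(x1 * x2, axis=-1)

def diversity_log_weight(group_emb, ref_emb):
    # Hypothetical helper: the value one would pass to numpyro.factor.
    # Lower similarity -> higher log weight -> dissimilar groups are favored.
    return -cosine_similarity(group_emb, ref_emb)

ref = jnp.array([[1.0, 0.0]])
similar = jnp.array([[0.9, 0.1]])
opposite = jnp.array([[-1.0, 0.0]])

# jax.jit traces the function, confirming it never needs concrete values.
log_w = jax.jit(diversity_log_weight)
print(log_w(similar, ref), log_w(opposite, ref))
```

Passing `-similarity` to `numpyro.factor` favors groups whose embedding is far from the reference; passing `+similarity` would favor groups close to it. This measures dissimilarity to a single reference embedding, which is only one possible definition of diversity.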

My questions are:

  1. Does my code make sense? I want to sample groups that are as diverse as possible.
  2. I keep getting this error: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method `__array__()` was called on traced array with shape int32[10].
    The error occurred while tracing the function while_body_fun at /xxxxxxxxxxxxxxxxx/python3.12/site-packages/jax/_src/lax/control_flow/loops.py:1999 for while_loop. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].
    See Errors — JAX documentation
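The error can be reproduced without PyTorch or NumPyro. Under `jax.jit` (and inside NumPyro's sampler, which runs in a traced `while_loop`, as the traceback shows), inputs are abstract tracers, and calling `numpy.asarray` on a tracer invokes `__array__()`, which JAX rejects. In this sketch the hypothetical `host_embedding` stands in for the PyTorch model, assuming the model converts its input to NumPy somewhere inside:

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_embedding(indices):
    # Stand-in for a non-JAX model: np.asarray needs concrete values.
    return np.asarray(indices, dtype=np.float32).reshape(1, -1)

@jax.jit
def traced_step(indices):
    # Under jit, `indices` is a tracer, not a concrete array.
    return host_embedding(indices)

def reproduces_error():
    try:
        traced_step(indices)
    except jax.errors.TracerArrayConversionError:
        return True
    return False

indices = jnp.arange(10, dtype=jnp.int32)
print(host_embedding(indices).shape)  # (1, 10): works outside tracing
print(reproduces_error())             # True: fails inside jit
```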

Could someone explain how to debug this?
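One way to call non-JAX code from traced code is `jax.pure_callback`: it runs a Python function on concrete NumPy values at execution time, and you declare the output shape and dtype up front so tracing can continue. The sketch assumes the PyTorch model can be wrapped in a function that takes a NumPy index array and returns a float32 array of shape `(1, 128)`; the hypothetical `embed_on_host` averages rows of a random table in place of the real model.

```python
import numpy as np
import jax
import jax.numpy as jnp

EMBED_DIM = 128
N_SEQS = 100

# Hypothetical stand-in for the PyTorch model: a fixed random table.
# A real version would call torch_model(...) and convert the result
# with .detach().cpu().numpy().
_rng = np.random.default_rng(0)
_table = _rng.normal(size=(N_SEQS, EMBED_DIM)).astype(np.float32)

def embed_on_host(indices):
    # Runs on the host with a concrete NumPy array of indices.
    indices = np.asarray(indices)
    return _table[indices].mean(axis=0, keepdims=True).astype(np.float32)

def embed(indices):
    # Declare the output shape/dtype so JAX can keep tracing.
    out_spec = jax.ShapeDtypeStruct((1, EMBED_DIM), jnp.float32)
    return jax.pure_callback(embed_on_host, out_spec, indices)

@jax.jit
def group_embedding(indices):
    return embed(indices)

indices = jnp.arange(10, dtype=jnp.int32)
emb = group_embedding(indices)
print(emb.shape)  # (1, 128)
```

JAX cannot differentiate through `pure_callback` without a custom derivative rule, so this only fits parts of the model where no gradient has to flow through the embedding. Calling back into PyTorch on every MCMC step may also be slow.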
Thanks in advance!!