Assume I have 100 sequences, e.g. [["A","C","R"],["D","E","F"], …]. Every 10 sequences are combined into a group, and each group is converted into an embedding by a PyTorch model. What I want is to sample groups that are as diverse as possible using DiscreteHMCGibbs in NumPyro.
My pseudocode looks like this:
def torch_embedding_function(torch_model, sequence_group):
    """Embed one group of sequences by calling ``torch_model`` on it.

    Args:
        torch_model: Callable that is invoked directly on ``sequence_group``.
        sequence_group: The group to embed; forwarded to ``torch_model``
            unchanged.

    Returns:
        Whatever ``torch_model`` returns. The original author notes the
        expected shape is (1, 128).
    """
    # NOTE(review): no conversion happens here, so a traced JAX value is handed
    # to torch_model as-is. That is presumably what raises
    # TracerArrayConversionError when this runs inside a NumPyro model --
    # confirm, and consider jax.pure_callback or precomputed embeddings.
    embedding = torch_model(sequence_group)
    return embedding
def cosine_similarity(x1, x2):
    """Return the cosine similarity of ``x1`` and ``x2`` along the last axis.

    Args:
        x1: Array whose last axis holds the vector components.
        x2: Array broadcastable against ``x1``.

    Returns:
        A JAX array of similarities with the last axis reduced away. It is
        deliberately not converted to a Python float so the function can be
        traced (``jax.jit``, MCMC kernels); call ``float(...)`` on a size-1
        result outside traced code if a plain number is needed.
    """
    # Clamp the norms so an all-zero vector yields similarity 0 instead of NaN,
    # which would otherwise poison any log density built from this value.
    eps = 1e-12
    norm_x1 = x1 / jnp.maximum(jnp.linalg.norm(x1, axis=-1, keepdims=True), eps)
    norm_x2 = x2 / jnp.maximum(jnp.linalg.norm(x2, axis=-1, keepdims=True), eps)
    # No `.item()` here: it needs a concrete value and fails on JAX tracers.
    return jnp.sum(norm_x1 * norm_x2, axis=-1)
def potential_function(projected_group1, projected_group2):
    """Return a penalty that grows as the two embeddings become more similar.

    The value is meant to be subtracted from the log density, e.g. via
    ``numpyro.factor(name, -penalty)``, so that similar groups become less
    likely and dissimilar (diverse) groups more likely.

    Args:
        projected_group1: Embedding of the first group.
        projected_group2: Embedding of the second group.

    Returns:
        The cosine similarity of the two embeddings, used as the penalty.
    """
    similarity = cosine_similarity(projected_group1, projected_group2)
    # The penalty must be +similarity: the caller negates it once, so returning
    # -similarity here would end up rewarding similar groups.
    return similarity
def propose_indices(n_seqs, group_size):
    """Draw ``group_size`` sequence indices from a uniform categorical.

    Registers the NumPyro sample site ``'sampled_indices'``. Each of the
    ``n_seqs`` indices has probability ``1 / n_seqs``; nothing here forces the
    drawn indices to be distinct.

    Args:
        n_seqs: Number of categories (candidate sequences).
        group_size: Number of indices to draw.

    Returns:
        The value of the ``'sampled_indices'`` site.
    """
    uniform_probs = jnp.ones(n_seqs) / n_seqs
    index_dist = dist.Categorical(probs=uniform_probs)
    return numpyro.sample(
        'sampled_indices',
        index_dist,
        sample_shape=(group_size,),
    )
def model(group_size, init_projected_group):
    """NumPyro model that scores a sampled group against a reference embedding.

    Samples ``group_size`` indices, embeds the resulting group, and adds
    ``-penalty`` to the log density, so groups with a larger penalty from
    ``potential_function`` are less likely. Reads the module-level names
    ``n_seqs`` and ``torch_model``.

    Args:
        group_size: Number of sequence indices per group.
        init_projected_group: Reference embedding the sampled group is
            compared with.

    Returns:
        The sampled indices.
    """
    sampled_indices = propose_indices(n_seqs, group_size)
    # NOTE(review): during MCMC `sampled_indices` is a traced value; confirm
    # that torch_embedding_function can consume it (see the reported
    # TracerArrayConversionError).
    # The embedding is bound to `projected_group`, the name the next line reads.
    projected_group = torch_embedding_function(torch_model, sampled_indices)
    penalty = potential_function(projected_group, init_projected_group)
    numpyro.factor('penalty', -penalty)  # Note the negative sign here
    return sampled_indices
# Wrap a NUTS kernel in DiscreteHMCGibbs and use the result as the MCMC sampler.
# NOTE(review): `model` declares only the discrete 'sampled_indices' site --
# confirm NUTS is usable as the inner kernel without a continuous latent site.
nuts_kernel = NUTS(model)
gibbs_kernel = DiscreteHMCGibbs(nuts_kernel)
mcmc = MCMC(
    gibbs_kernel,
    num_warmup=500,
    num_samples=1000,
)

# The arguments after the PRNG key are the ones `model` declares
# (group_size, init_projected_group).
mcmc.run(random.PRNGKey(0), group_size, init_projected_group)

samples = mcmc.get_samples()
print("Sampled indices:", samples['sampled_indices'])
My questions are:
- Does my code make sense? What I want is to sample groups that are as diverse as possible.
- I always get the error: `jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[10].`
The error occurred while tracing the function while_body_fun at /xxxxxxxxxxxxxxxxx/python3.12/site-packages/jax/_src/lax/control_flow/loops.py:1999 for while_loop. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].
See Errors — JAX documentation
Could someone explain how to debug this?
Thanks in advance!!