Numpyro + Flax?

Hi!
I am just writing with a quick question. Flax is currently in a pre-release state but hopefully will be fully released soon. JAX is currently on version 1.41 (and Flax only works properly with the latest version of JAX), while NumPyro right now works with version 1.37. Do you have an estimate of when everything could be working together, fully updated? Just wondering :blush:, I am aware it is a lot of work.

Thanks and take care! :slight_smile:

@artistworking - We should be able to update NumPyro soon so that it works with JAX 1.41. Unless there are some breaking changes in JAX, it might already work, but I'll need to check. Beyond that, I don't suppose we need to do anything else to support libraries like Flax. Is there anything in particular that you are looking to do with Flax and NumPyro?

Hi!

Thanks! :slight_smile: I am working on a generative model that involves HMC (NUTS) and RNN/GRU/LSTM.

I think the easiest way is to write a potential_fn, which takes as inputs your "required-to-sample" parameters. If you want to add a Normal() prior for a parameter "weight", simply add Normal().log_prob(weight) to the joint density. You can also write helpers to make this job easier for you.

Currently, NumPyro supports Stax modules to optimize parameters using SVI. To do inference over the parameters of nn modules in MCMC, it is better to use potential_fn as above (this applies to Stax, Flax, Haiku, ...).
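For concreteness, here is a rough sketch of the potential_fn approach (not from the original code): the potential is the negative log joint density over the parameters you want to sample. The toy data and the Normal likelihood below are made up stand-ins for your model.

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

y = jnp.array([0.5, -0.2, 1.1])  # toy observations (made up)

def potential_fn(params):
    weight = params["weight"]
    # prior: Normal(0, 1) on "weight", added to the joint log density
    log_joint = dist.Normal(0.0, 1.0).log_prob(weight).sum()
    # likelihood: a stand-in for your model, here y ~ Normal(weight, 1)
    log_joint = log_joint + dist.Normal(weight, 1.0).log_prob(y).sum()
    # NUTS expects the potential, i.e. the negative log joint
    return -log_joint

kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
# with a potential_fn you need to supply initial parameter values
mcmc.run(random.PRNGKey(0), init_params={"weight": jnp.zeros(())})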

Whoa, ok. I think I have some reading to do then :slight_smile:. Let's see what I manage to do. I am implementing the model in Pyro first, to see if the idea works and because I have more experience with it, and then I will port it to NumPyro. Thanks again!!!

@artistworking FYI, Tuan Nguyen has made a great contribution on this integration. In the master branch, you will be able to use a Flax module in SVI to optimize parameters. In a follow-up PR, you will be able to turn an nn into a Bayesian nn and run MCMC to get samples of the neural network's parameters. :slight_smile:


Sorry for the late reply, I missed this message for some reason, but thank you very much! I definitely need Flax for quite a few things. I implemented my own GRU, but it did not seem very fast.

FYI, you can use random_flax_module to set priors for the parameters of your neural network. We have only tested it for some dense layers, so if you observe any problems with GRU, please let us know. Thanks! :slight_smile:
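For example, a rough sketch with dense layers (using the pre-Linen flax.nn API from this thread; the MLP module, shapes, and site names here are made up for illustration):

import flax.nn as nn  # pre-Linen Flax API
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module

class MLP(nn.Module):
    def apply(self, x, hidden_dim, out_dim):
        x = nn.relu(nn.Dense(x, hidden_dim))
        return nn.Dense(x, out_dim)

def model(x, y=None):
    # place a Normal(0, 1) prior on every parameter of the module
    net = random_flax_module(
        "net", MLP.partial(hidden_dim=30, out_dim=1),
        prior=dist.Normal(0.0, 1.0), input_shape=x.shape)
    mu = net(x).squeeze(-1)
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)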

Ok, thanks! I will definitely have to do that at some point soon :slight_smile:. Thank you so much!

Hi again @fehiepsi! Just letting you know that random_flax_module seems to cooperate with the GRU/LSTM from Flax during the training phase. However, I have a question: how do I use the learnt parameters from the GRU to generate samples? I haven't been able to figure it out.

This is my training class:

import flax
import flax.nn as nn  # pre-Linen Flax API
from jax import random

class combinerRNN(nn.Module):

    def apply(self, children, hidden_dim):
        # initialize the GRU carry (hidden state), one row per batch element
        rng = random.PRNGKey(0)
        with nn.stochastic(rng):
            carry = nn.GRUCell.initialize_carry(
                nn.make_rng(), (children.shape[0],), hidden_dim)

        # run the GRU cell over the sequence dimension
        _, logits = flax.jax_utils.scan_in_dim(
            nn.GRUCell.partial(name='gru_cell'), carry, children, axis=1)
        logits = logits[:, -1]  # keep the output of the last time step
        # max_seq_len and aa_prob are globals defined elsewhere
        logits = nn.Dense(logits, max_seq_len * aa_prob, bias=True, name='output')
        logits = nn.log_softmax(logits)

        return logits.reshape(children.shape[0], max_seq_len, aa_prob)

which is called within the model (a bunch of times) as:

module = combinerRNN.partial(hidden_dim=30)
net = random_flax_module("nn_{}".format(int(current_ancestor)), module,
                         prior=dist.Normal(0, 1), input_shape=children_samples.shape)
logits = net(children=children_samples)

I have the learnt parameters in here:

net_params = svi_params["nn_{}$params".format(int(current_ancestor))]

I need to use those learnt parameters to obtain the logits and then use them for sampling.

Sorry if it's a very obvious question, and thanks :slight_smile:!

Hi @artistworking, I guess by sampling you mean making predictions? If so, you can use the Predictive class as in the example at random_flax_module. If you only want to get the logits, you can add a deterministic site to record their value in the output of Predictive.

Edit: I think I was wrong. Your net is a Bayesian nn, so net_params should contain empty arrays. Could you check again? If so, you can add numpyro.deterministic("nn_params", net.args[0]) to the model to get "sampled" values of your module's parameters. But if you only need to get the logits, then you can just use Predictive as above; a rough sketch is below.
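Something like this (a sketch, not the exact code for your model; model, guide, and svi_params are placeholders for your own setup, and "logits" is whatever site name you choose):

from jax import random
from numpyro.infer import Predictive

# inside the model, after computing `logits`, record them:
#     numpyro.deterministic("logits", logits)

predictive = Predictive(model, guide=guide, params=svi_params,
                        num_samples=100)
samples = predictive(random.PRNGKey(1), children_samples)
logits_samples = samples["logits"]  # one set of logits per posterior draw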

Hi! Thanks for the quick reply. I had an intuition it could be done with the Predictive class, but I was trying to make "individual predictions", i.e. using only some parts of the model. I am using a tree, and the model function iterates through the whole tree. When I make predictions, it would be great to be able to predict the logits for only one node and then sample just from that node's logits, instead of all of them. But it's ok now, I have implemented the predictions with Predictive. The net_params seem empty indeed... well, they only show things like weights = ShapedArray(3, 30), no values... but it's JAX, so idk (I could not run it without jit yet, because of some other error when disabling it). Thanks again! :slight_smile:

Hi @artistworking, glad that you can make it work now.

I could not run it without jit yet

I would like to chat with you to understand the issue better. Do you have some time to discuss this weekend or next week? Thanks! (my email: fehiepsi at gmail)

email sent :slight_smile:
