Hi again @fehiepsi! Just letting you know that random_flax_module seems to work with Flax's GRU/LSTM cells during training. However, I'm not sure how to use the learnt GRU parameters to generate samples, and I haven't been able to figure it out.
This is my training class:
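import flax
from flax import nn  # legacy (pre-Linen) Flax API
from jax import random

# max_seq_len and aa_prob are globals defined elsewhere
class combinerRNN(nn.Module):
    def apply(self, children, hidden_dim):
        rng = random.PRNGKey(0)
        with nn.stochastic(rng):
            # GRU carry, one per sequence in the batch
            carry = nn.GRUCell.initialize_carry(nn.make_rng(), (children.shape[0],), hidden_dim)
            # scan the GRU cell over the sequence axis (axis 1) of children
            _, logits = flax.jax_utils.scan_in_dim(
                nn.GRUCell.partial(name='gru_cell'), carry, children, axis=1)
        logits = logits[:, -1]  # GRU output at the last time step
        logits = nn.Dense(logits, max_seq_len * aa_prob, bias=True, name='output')
        logits = logits.reshape(children.shape[0], max_seq_len, aa_prob)
        # normalize over the amino-acid axis at each position
        return nn.log_softmax(logits)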
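For reference, here is a minimal pure-JAX sketch of what the module above computes, with the parameters passed in explicitly instead of being managed by Flax. It shows that, once training is done, producing logits is just a deterministic function of (params, children): scan a GRU over axis 1, keep the last output, apply a dense layer, reshape to (batch, max_seq_len, aa_prob) and normalize per position. The gate convention, the parameter dict layout and the names `gru_step`, `combiner_forward` and `init_params` are assumptions for illustration; they do not match Flax's internal parameter layout, so Flax's trained parameters cannot be plugged in directly.

```python
import jax
import jax.numpy as jnp


def gru_step(p, h, x):
    """One GRU update (PyTorch-style gates; an assumption, Flax's layout differs)."""
    gx = x @ p["W_x"] + p["b_x"]  # input contributions, (batch, 3 * hidden)
    gh = h @ p["W_h"] + p["b_h"]  # hidden contributions, (batch, 3 * hidden)
    xr, xz, xn = jnp.split(gx, 3, axis=-1)
    hr, hz, hn = jnp.split(gh, 3, axis=-1)
    r = jax.nn.sigmoid(xr + hr)  # reset gate
    z = jax.nn.sigmoid(xz + hz)  # update gate
    n = jnp.tanh(xn + r * hn)  # candidate state
    return (1.0 - z) * n + z * h


def combiner_forward(params, children, max_seq_len, aa_prob):
    """children: (batch, seq_len, features) -> log-probs (batch, max_seq_len, aa_prob)."""
    batch = children.shape[0]
    hidden = params["gru"]["W_h"].shape[0]
    h0 = jnp.zeros((batch, hidden))  # zero initial carry

    def step(h, x_t):
        h = gru_step(params["gru"], h, x_t)
        return h, h

    # lax.scan walks the leading axis, so move the sequence axis (axis 1) first;
    # the final carry equals the GRU output at the last time step.
    h_last, _ = jax.lax.scan(step, h0, jnp.swapaxes(children, 0, 1))
    out = h_last @ params["out"]["W"] + params["out"]["b"]
    out = out.reshape(batch, max_seq_len, aa_prob)
    return jax.nn.log_softmax(out, axis=-1)  # normalize at each position


def init_params(key, features, hidden, max_seq_len, aa_prob):
    """Small random parameters in the hypothetical layout used above."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "gru": {
            "W_x": 0.1 * jax.random.normal(k1, (features, 3 * hidden)),
            "W_h": 0.1 * jax.random.normal(k2, (hidden, 3 * hidden)),
            "b_x": jnp.zeros(3 * hidden),
            "b_h": jnp.zeros(3 * hidden),
        },
        "out": {
            "W": 0.1 * jax.random.normal(k3, (hidden, max_seq_len * aa_prob)),
            "b": jnp.zeros(max_seq_len * aa_prob),
        },
    }


key_data, key_params = jax.random.split(jax.random.PRNGKey(0))
children = jax.random.normal(key_data, (4, 7, 21))  # (batch, seq_len, features)
params = init_params(key_params, features=21, hidden=30, max_seq_len=7, aa_prob=21)
log_probs = combiner_forward(params, children, max_seq_len=7, aa_prob=21)
print(log_probs.shape)  # (4, 7, 21)
```

Normalizing after the reshape makes each position a proper distribution over amino acids; applying log_softmax to the flat max_seq_len * aa_prob vector would instead normalize over all positions at once.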
It is called within the model (several times) as:
from numpyro.contrib.module import random_flax_module
import numpyro.distributions as dist

module = combinerRNN.partial(hidden_dim=30)
net = random_flax_module("nn_{}".format(int(current_ancestor)), module, prior=dist.Normal(0, 1), input_shape=children_samples.shape)
logits = net(children=children_samples)
The learnt parameters are here:
net_params = svi_params["nn_{}$params".format(int(current_ancestor))]
I need to use these learnt parameters to compute the logits and then sample from them.
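A minimal, hedged sketch of the generation step, assuming the trained network can be wrapped as a pure function `apply_fn(params, children)` that returns logits of shape (batch, max_seq_len, aa_prob). With the legacy `flax.nn` API, something like `combinerRNN.partial(hidden_dim=30).call(net_params, children=children_samples)` should evaluate a module with given parameters, but check this against your Flax version. Within NumPyro, `numpyro.infer.Predictive(model, guide=guide, params=svi_params, num_samples=...)` may be the simpler route, since random_flax_module treats the network weights as latent sample sites. The names `sample_from_params` and `toy_apply` below are hypothetical; `toy_apply` is only a stand-in so the sketch runs.

```python
import jax
import jax.numpy as jnp


def sample_from_params(apply_fn, params, children, key, num_samples=1):
    """Run the network with fixed (learned) params, then draw one category per position.

    Returns integer samples of shape (num_samples, batch, max_seq_len).
    """
    logits = apply_fn(params, children)  # (batch, max_seq_len, aa_prob)
    # categorical accepts unnormalized logits, so log_softmax output can be passed as-is;
    # `shape` prepends a sample dimension and broadcasts over (batch, max_seq_len).
    return jax.random.categorical(
        key, logits, axis=-1, shape=(num_samples,) + logits.shape[:-1]
    )


# Hypothetical stand-in for the trained network: one linear layer, 5 positions x 3 classes.
def toy_apply(params, children):
    batch = children.shape[0]
    flat = children.reshape(batch, -1) @ params["W"]
    return jax.nn.log_softmax(flat.reshape(batch, 5, 3), axis=-1)


key_data, key_w, key_sample = jax.random.split(jax.random.PRNGKey(0), 3)
children = jax.random.normal(key_data, (2, 4, 6))  # (batch, seq_len, features)
params = {"W": jax.random.normal(key_w, (4 * 6, 5 * 3))}
samples = sample_from_params(toy_apply, params, children, key_sample, num_samples=10)
print(samples.shape)  # (10, 2, 5)
```

Keeping the forward pass and the sampling separate means the same learned parameters can be reused for any number of draws by changing only the PRNG key.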
Sorry if it's a very obvious question, and thanks!