Hi everyone,
I’m new to Pyro and I’m trying to understand how to feed the output of one distribution into my next node.
Essentially, I want to ‘stack’ outputs such as:
```python
import torch, pyro
import pyro.distributions as dist

vs = []
for i in range(n_nodes):
    v = pyro.sample(f"v{i}", dist.Bernoulli(torch.tensor([0.3])))
    vs.append(v)
```
This works, but the issue is that `vs` is a list of Pyro objects, so I cannot apply Pyro functions to it afterwards. I don't see how to build such a concatenation of outputs and then use the values without breaking my graph dependencies.
For example, doing something like this:
```python
vs = torch.zeros(n_nodes)
for i in range(n_nodes):
    v = pyro.sample(f"v{i}", dist.Bernoulli(torch.tensor([0.3])))
    vs[i] = v
```
breaks the flow of my graph, because `vs` is a `torch.Tensor` and no longer a Pyro object.
Any help is welcome!