Is this of any help to you?
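```python
import pyro

# Run the model once and record every site it executes
# (SCM, vae, mu and sigma are the objects from your code)
trace = pyro.poutine.trace(SCM(vae, mu, sigma)).get_trace()
trace.compute_log_prob()  # optional, but adds the log_prob shapes to the printout
print(trace.format_shapes())
```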
Output:

```
 Trace Shapes:
  Param Sites:
 Sample Sites:
  N_X     dist   4096 |
         value   4096 |
      log_prob   4096 |
N_Y_1     dist      3 |
         value      3 |
      log_prob      3 |
N_Y_2     dist      6 |
         value      6 |
      log_prob      6 |
N_Y_3     dist     40 |
         value     40 |
      log_prob     40 |
N_Y_4     dist     32 |
         value     32 |
      log_prob     32 |
N_Y_5     dist     32 |
         value     32 |
      log_prob     32 |
  N_Z     dist     50 |
         value     50 |
      log_prob     50 |
    Z     dist 1   50 |
         value 1   50 |
      log_prob 1   50 |
  Y_1     dist        |
         value        |
      log_prob        |
  Y_2     dist        |
         value        |
      log_prob        |
  Y_3     dist        |
         value        |
      log_prob        |
  Y_4     dist        |
         value        |
      log_prob        |
  Y_5     dist        |
         value        |
      log_prob        |
    X     dist 1 4096 |
         value 1 4096 |
      log_prob 1 4096 |
```