```
ValueError: at site "N_X", invalid log_prob shape
  Expected [], actual [4096]
```

Is this of any help to you? I traced the model and printed the shapes of every site:

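```python
import pyro

# SCM(vae, mu, sigma) builds the model being debugged (defined elsewhere)
trace = pyro.poutine.trace(SCM(vae, mu, sigma)).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
```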

Output:

```
Trace Shapes:
 Param Sites:
Sample Sites:
     N_X dist   4096 |
        value   4096 |
     log_prob   4096 |
   N_Y_1 dist      3 |
        value      3 |
     log_prob      3 |
   N_Y_2 dist      6 |
        value      6 |
     log_prob      6 |
   N_Y_3 dist     40 |
        value     40 |
     log_prob     40 |
   N_Y_4 dist     32 |
        value     32 |
     log_prob     32 |
   N_Y_5 dist     32 |
        value     32 |
     log_prob     32 |
     N_Z dist     50 |
        value     50 |
     log_prob     50 |
       Z dist 1   50 |
        value 1   50 |
     log_prob 1   50 |
     Y_1 dist        |
        value        |
     log_prob        |
     Y_2 dist        |
        value        |
     log_prob        |
     Y_3 dist        |
        value        |
     log_prob        |
     Y_4 dist        |
        value        |
     log_prob        |
     Y_5 dist        |
        value        |
     log_prob        |
       X dist 1 4096 |
        value 1 4096 |
     log_prob 1 4096 |
```
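One way to read this table: in `format_shapes()`, the dims to the left of `|` are batch dims and the dims to the right are event dims. Every `N_*` site, plus `Z` and `X`, has batch dims (e.g. `N_X`'s log_prob is `[4096]`), while the `Expected []` in the error suggests that no plate dimension was registered for `N_X`, so Pyro wants a scalar log_prob there. The sketch below is a simplified, pure-Python re-implementation of that right-aligned comparison, assuming the shapes above and no plates. `check_log_prob_shape` and its `plate_sizes` argument are hypothetical names, not Pyro's API, and the sketch ignores details such as `max_plate_nesting`.

```python
from itertools import zip_longest


def check_log_prob_shape(name, log_prob_shape, plate_sizes=None):
    """Hypothetical, simplified sketch (not Pyro's implementation) of the
    check behind "invalid log_prob shape".

    plate_sizes maps a negative dim (e.g. -1) to the size of a plate declared
    at that dim. Batch dims of log_prob must be covered by plates; with no
    plates at all, log_prob must be a scalar (shape []).
    """
    plate_sizes = plate_sizes or {}
    # -1 marks a position inside the plate range that no plate claims;
    # it is not checked.
    expected = [-1] * max((-dim for dim in plate_sizes), default=0)
    for dim, size in plate_sizes.items():
        expected[dim] = size
    actual = list(log_prob_shape)
    # Compare right-aligned; dims missing on either side count as size 1.
    for a, e in zip_longest(reversed(actual), reversed(expected), fillvalue=1):
        if e != -1 and e != a:
            raise ValueError(
                f'at site "{name}", invalid log_prob shape\n'
                f"  Expected {expected}, actual {actual}"
            )


# log_prob shapes copied from the format_shapes() output above (no plates).
log_prob_shapes = {
    "N_X": [4096], "N_Y_1": [3], "N_Y_2": [6], "N_Y_3": [40],
    "N_Y_4": [32], "N_Y_5": [32], "N_Z": [50], "Z": [1, 50],
    "Y_1": [], "Y_2": [], "Y_3": [], "Y_4": [], "Y_5": [], "X": [1, 4096],
}

bad_sites = []
for site, shape in log_prob_shapes.items():
    try:
        check_log_prob_shape(site, shape)
    except ValueError as err:
        bad_sites.append(site)
        print(err)
```

Run on these shapes, the sketch flags every `N_*` site as well as `Z` and `X`, while the scalar `Y_*` sites pass; Pyro presumably stops at the first failing site, which would explain why the error names `N_X`. The usual remedies in Pyro are to move such dims into the event shape, e.g. calling `.to_event(1)` on the distribution sampled at `N_X` so its row shows `4096` to the right of the `|`, or to declare them with `pyro.plate`. Which one fits depends on how the model is meant to treat those dims.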