SVI converges in a complex discrete model, but infer_discrete results are nonsense

Thanks for the detailed responses above.

Just from debugging the model dimensions: I thought I was clear on what the dimensions meant, but it seems maybe not. With model_0 of the hmm example, how do the dist and value sites work out to the shapes they have for x/y? I know that dist is the shape of site["fn"] (the distribution) and value is the shape of the output value, so how do we lose the enumeration dimension in value? I guess it is because y is observed, so the enumeration part is collapsed. And then any size-1 dimensions that exist to the left of the tones plate dimension in y are collapsed as well, since the plates essentially post-modify the shapes of the sample sites.

The enumeration works the opposite way for x, alternating instead. Would that be a fair summary? You basically have alternating dimensions in the markov block, so that you carry the previous state but can sum out any of the preceding time steps. This does mean that with three custom initial states I end up carrying three extra enumeration dimensions through the entire markov block.

The batch dimension being retained in the value of y but not of x is presumably also due to y being observed?
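To make the "alternating dimensions" picture concrete, here is a stdlib-only toy allocator. This is an illustration of the idea only, not Pyro's actual implementation; the function name and `first_dim` default are made up for the sketch (in Pyro the first enumeration dim sits to the left of all plate dims).

```python
def allocate_enum_dims(num_steps, history=1, first_dim=-2):
    """Toy sketch: assign an enumeration dim to each x_t, reusing any dim
    that is no longer inside the markov history window."""
    dims = []
    for t in range(num_steps):
        # only the dims of the last `history` variables are still "alive"
        blocked = set(dims[-history:])
        d = first_dim
        while d in blocked:
            d -= 1  # move one dim further left until we find a free one
        dims.append(d)
    return dims

# with history=1 the dims alternate between two slots forever
print(allocate_enum_dims(6))             # -> [-2, -3, -2, -3, -2, -3]
# with history=2 three slots cycle; without pyro.markov, every variable
# would need its own fresh dim and this list would grow without bound
print(allocate_enum_dims(4, history=2))  # -> [-2, -3, -4, -2]
```

This matches the quoted hmm.py comment: with `pyro.markov`, enumeration over 16 states only ever occupies two alternating dimensions rather than one per time step.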

Yes, observed discrete variables do not need to be enumerated.

Basically, the enumeration dimensions are summed out greedily (as soon as possible) so that they do not grow without bound. See e.g. this comment for model_0:

Notice that enumeration (over 16 states) alternates between two dimensions:
-2 and -3. If we had not used pyro.markov above, each enumerated variable
would need its own enumeration dimension.
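The greedy sum-out can be sketched with a tiny pure-Python forward pass over a 3-state chain. This is only an illustration of the principle, not Pyro's actual message passing, and the probability values are made up: the previous state's dimension is eliminated at every step, so only one length-K vector is ever carried, no matter how long the chain is.

```python
K = 3
init = [0.5, 0.3, 0.2]        # p(x_0), toy values
trans = [[0.8, 0.1, 0.1],     # p(x_t | x_{t-1}), rows sum to 1
         [0.2, 0.7, 0.1],
         [0.3, 0.3, 0.4]]

def forward(num_steps):
    alpha = init[:]           # marginal p(x_0)
    for _ in range(num_steps):
        # sum out the previous state immediately: the old dimension never
        # survives to the next step, so storage stays O(K), not O(K^t)
        alpha = [sum(alpha[i] * trans[i][j] for i in range(K))
                 for j in range(K)]
    return alpha              # marginal p(x_t), still just K numbers

print(forward(10))            # a length-3 distribution summing to ~1.0
```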

Thanks :smiley: I had seen that diagram/statement and thought it was a good job, given the dimension limits :wink:

I have debugged the dimensions and they look correct, and the model and the numbers seem to feed through as expected. Any other tips for common mistakes/thoughts on where things might be going wrong?

If you wouldn’t mind taking a look, here is the debug output. (I don’t think there is an error in here, except maybe in probs_y1_t and the resulting indexing, which I have some misgivings about on reflection; but obviously there is one somewhere.)

(tensor bodies elided; names, types, and shapes shown)

K             = {int} 16
args          = {Namespace} Namespace(saved=False, num_steps=50, batch_size=30, sequence_len=104, hidden_dim=16, learning_rate=0.01, file_path=WindowsPath('D:/Work miscellaneous/Downloads/ws300_level3_MultiNB_kmers_5_15_train.csv'), seed=0)
batch         = {Tensor: (30,)} tensor([0, 1, 2, ..., 29])
data1         = {Tensor: (30, 52)}
data2         = {Tensor: (30, 52)} (identical to data1)
data_dim      = {int} 1
epsilonyx     = {Tensor: (16, 16)}
lengths       = {int} 52
muovtheta     = {Tensor: ()} tensor(30.8225, grad_fn=<SumBackward0>)
muy           = {Tensor: (16,)}
num_sequences = {int} 30
prob_init     = {Tensor: (16,)}
probs_s       = {Tensor: (1,)} tensor([0.0100], grad_fn=<ClampBackward1>)
probs_x       = {Tensor: (16, 16)}
probs_x1_t    = {Tensor: (16, 16, 2, 1, 16)}
probs_x2_t    = {Tensor: (16, 16, 2, 1, 16)}
probs_y       = {Tensor: (16, 16, 16, 16)}
probs_y1_t    = {Tensor: (16, 16, 16, 2, 30, 16)}
probs_y2_t    = {Tensor: (16, 1, 16, 16, 2, 30, 16)}
s             = {Tensor: (2, 1, 1, 1, 1, 1, 1)} (values 0., 1.)
sequences     = {Tensor: (30, 104, 1)}
sigma         = {Tensor: (1,)} tensor([0.0100], grad_fn=<ClampBackward1>)
t             = {int} 1
thetay        = {Tensor: (16,)}
tmuxgivy      = {Tensor: (16, 16)}
x1            = {Tensor: (16, 1, 1, 1, 1)} (values 0..15 along dim 0)
x1_help       = {Tensor: (16, 16, 2, 1)}
x1_list       = {list: 1}
x2            = {Tensor: (16, 1, 1, 1, 1, 1)} (values 0..15 along dim 0)
x2_help       = {Tensor: (16, 16, 2, 1)}
x2_list       = {list: 1}
y1            = {Tensor: (30,)} tensor([ 8,  1,  1, 11, 15,  3, 10, 13, 14, 14, 11,  2,  1,  1,  7, 10,  3, 15,  7,  8,  1, 13, 15,  7, 12, 12,  6,  2, 11,  6])
y1_help       = {Tensor: (2, 30)} (y1 repeated twice)
y2            = {Tensor: (30,)} (identical to y1)
y2_help       = {Tensor: (2, 30)} (y2 repeated twice)

I can reformat the output to show just the shapes, if that would be helpful?

The other “half” of the Markov output is the following:

(entries unchanged from the previous dump omitted: K, args, batch, data1, data2, data_dim, epsilonyx, lengths, muovtheta, muy, num_sequences, prob_init, probs_s, probs_x, probs_y, sequences, sigma, thetay, tmuxgivy, y1, y2)

probs_x1_t    = {Tensor: (2, 16, 16, 1, 1, 1, 1, 16)}
probs_x2_t    = {Tensor: (2, 16, 16, 1, 1, 1, 1, 16)}
probs_y1_t    = {Tensor: (16, 2, 16, 16, 1, 1, 1, 30, 16)}
probs_y2_t    = {Tensor: (16, 1, 2, 16, 16, 1, 1, 1, 30, 16)}
s             = {Tensor: (2, 1, 1, 1, 1, 1, 1, 1, 1, 1)} (values 0., 1.)
t             = {int} 6
x1            = {Tensor: (16, 1, 1, 1, 1, 1, 1, 1)} (values 0..15 along dim 0)
x1_help       = {Tensor: (2, 16, 16, 1, 1, 1, 1)}
x1_list       = {list: 6}
x2            = {Tensor: (16, 1, 1, 1, 1, 1, 1, 1, 1)} (values 0..15 along dim 0)
x2_help       = {Tensor: (2, 16, 16, 1, 1, 1, 1)}
x2_list       = {list: 6}
y1_help       = {Tensor: (2, 1, 1, 1, 1, 1, 30)} (y1 repeated twice)
y2_help       = {Tensor: (2, 1, 1, 1, 1, 1, 30)} (y2 repeated twice)

Presumably because all the enumerated dimensions were summed out; but if that were the case, what would have been the use of the x.squeeze(-1) in the hmm example?

sorry but i’m not able to do much with that kind of output.

you might benefit from printing out model_trace.format_shapes() like in hmm.py

The first 10 time steps. Thanks.

 Trace Shapes:                              
  Param Sites:                              
         sigma                             1
           muy                            16
      tmuxgivy                         16 16
        thetay                            16
     epsilonyx                         16 16
 Sample Sites:                              
sequences dist                             |
         value                         30  |
      s_0 dist                         30  |
         value                       2  1  |
     x1_0 dist                         30  |
         value                    16 1  1  |
     x2_0 dist                         30  |
         value                 16  1 1  1  |
     y1_0 dist                    16 1 30  |
         value                         30  |
     y2_0 dist                 16  1 1 30  |
         value                         30  |
     x1_1 dist                 16 16 2 30  |
         value              16  1  1 1  1  |
     x2_1 dist                 16 16 2 30  |
         value           16  1  1  1 1  1  |
      s_1 dist                         30  |
         value         2  1  1  1  1 1  1  |
     y1_1 dist              16 16 16 2 30  |
         value                         30  |
     y2_1 dist           16  1 16 16 2 30  |
         value                         30  |
     x1_2 dist         2 16 16  1  1 1 30  |
         value      16 1  1  1  1  1 1  1  |
     x2_2 dist         2 16 16  1  1 1 30  |
         value   16  1 1  1  1  1  1 1  1  |
      s_2 dist                         30  |
         value 2  1  1 1  1  1  1  1 1  1  |
     y1_2 dist      16 2 16 16  1  1 1 30  |
         value                         30  |
     y2_2 dist   16  1 2 16 16  1  1 1 30  |
         value                         30  |
     x1_3 dist 2 16 16 1  1  1  1  1 1 30  |
         value              16  1  1 1  1  |
     x2_3 dist 2 16 16 1  1  1  1  1 1 30  |
         value           16  1  1  1 1  1  |
      s_3 dist                         30  |
         value         2  1  1  1  1 1  1  |
     y1_3 dist 2 16 16 1  1 16  1  1 1 30  |
         value                         30  |
     y2_3 dist 2 16 16 1 16  1  1  1 1 30  |
         value                         30  |
     x1_4 dist         2 16 16  1  1 1 30  |
         value      16 1  1  1  1  1 1  1  |
     x2_4 dist         2 16 16  1  1 1 30  |
         value   16  1 1  1  1  1  1 1  1  |
      s_4 dist                         30  |
         value 2  1  1 1  1  1  1  1 1  1  |
     y1_4 dist      16 2 16 16  1  1 1 30  |
         value                         30  |
     y2_4 dist   16  1 2 16 16  1  1 1 30  |
         value                         30  |
     x1_5 dist 2 16 16 1  1  1  1  1 1 30  |
         value              16  1  1 1  1  |
     x2_5 dist 2 16 16 1  1  1  1  1 1 30  |
         value           16  1  1  1 1  1  |
      s_5 dist                         30  |
         value         2  1  1  1  1 1  1  |
     y1_5 dist 2 16 16 1  1 16  1  1 1 30  |
         value                         30  |
     y2_5 dist 2 16 16 1 16  1  1  1 1 30  |
         value                         30  |
     x1_6 dist         2 16 16  1  1 1 30  |
         value      16 1  1  1  1  1 1  1  |
     x2_6 dist         2 16 16  1  1 1 30  |
         value   16  1 1  1  1  1  1 1  1  |
      s_6 dist                         30  |
         value 2  1  1 1  1  1  1  1 1  1  |
     y1_6 dist      16 2 16 16  1  1 1 30  |
         value                         30  |
     y2_6 dist   16  1 2 16 16  1  1 1 30  |
         value                         30  |
     x1_7 dist 2 16 16 1  1  1  1  1 1 30  |
         value              16  1  1 1  1  |
     x2_7 dist 2 16 16 1  1  1  1  1 1 30  |
         value           16  1  1  1 1  1  |
      s_7 dist                         30  |
         value         2  1  1  1  1 1  1  |
     y1_7 dist 2 16 16 1  1 16  1  1 1 30  |
         value                         30  |
     y2_7 dist 2 16 16 1 16  1  1  1 1 30  |
         value                         30  |
     x1_8 dist         2 16 16  1  1 1 30  |
         value      16 1  1  1  1  1 1  1  |
     x2_8 dist         2 16 16  1  1 1 30  |
         value   16  1 1  1  1  1  1 1  1  |
      s_8 dist                         30  |
         value 2  1  1 1  1  1  1  1 1  1  |
     y1_8 dist      16 2 16 16  1  1 1 30  |
         value                         30  |
     y2_8 dist   16  1 2 16 16  1  1 1 30  |
         value                         30  |
     x1_9 dist 2 16 16 1  1  1  1  1 1 30  |
         value              16  1  1 1  1  |
     x2_9 dist 2 16 16 1  1  1  1  1 1 30  |
         value           16  1  1  1 1  1  |
      s_9 dist                         30  |
         value         2  1  1  1  1 1  1  |
     y1_9 dist 2 16 16 1  1 16  1  1 1 30  |
         value                         30  |
     y2_9 dist 2 16 16 1 16  1  1  1 1 30  |
         value                         30  |
    x1_10 dist         2 16 16  1  1 1 30  |
         value      16 1  1  1  1  1 1  1  |
    x2_10 dist         2 16 16  1  1 1 30  |
         value   16  1 1  1  1  1  1 1  1  |
     s_10 dist                         30  |
         value 2  1  1 1  1  1  1  1 1  1  |
    y1_10 dist      16 2 16 16  1  1 1 30  |
         value                         30  |
    y2_10 dist   16  1 2 16 16  1  1 1 30  |
         value                         30  |

you probably want to put x1_0 etc inside the markov context otherwise i guess you’ll be instantiating larger tensors than necessary

Thanks for the reply. The problem with that is that the chain is initialized (at time step 0) with an initial context that is somewhat different from the Markov process proper. The only way I could immediately see of doing that would be to branch on t (at every step) inside the markov block, which I figured was more undesirable?

the code may be uglier but it’ll likely be faster
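
To sketch what that branch-on-t structure looks like, here is a pure-Python stand-in for the pattern, with pyro.sample and the distributions replaced by plain callables (so this is only an illustration of the control flow, not Pyro code):

```python
# Pure-Python stand-in for the "branch on t inside pyro.markov" pattern:
# the t == 0 case draws from a special initial distribution, and every
# later step applies the ordinary Markov transition.
def run_chain(T, sample_init, sample_transition):
    xs = []
    x = None
    for t in range(T):  # in the real model: for t in pyro.markov(range(T)):
        if t == 0:
            x = sample_init()          # special initial context
        else:
            x = sample_transition(x)   # ordinary Markov transition
        xs.append(x)
    return xs

print(run_chain(5, lambda: 0, lambda x: x + 1))  # [0, 1, 2, 3, 4]
```

The per-step branch costs essentially nothing, while keeping the initial sample inside the markov context lets Pyro reuse the alternating enumeration dimensions for it.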

No problem, in that case, I’ll make that change, thanks.

Internally, in terms of matching indices for the multi-indexing, the idea would be to match all of the enumerated dimensions for the same variables and then index on the (enumerated) value of the particular variable being fed at that position (the one you wanted to index on)? So the index should take the leftmost dimension value of that variable? I’m just in the middle of looking over the source code to see how Vindex works.

i’m not sure how helpful words are in the context. probably the most useful thing is to take a very close look at hmm.py and make sure you understand why different latent variables etc. have precisely the shapes they do

Thanks. In that case, why do we need to .squeeze(-1) the x in model_0, in the line:

with tones_plate:
    pyro.sample(
        "y_{}_{}".format(i, t),
        dist.Bernoulli(probs_y[x.squeeze(-1)]),
        obs=sequence[t],
    )

Since this basically removes the trailing 1 from x and converts it into a 1D tensor, right? Which presumably prevents the Bernoulli sample from retaining this shape. Is it that the summation resulting from the observation would otherwise have problems dealing with this?

And is it possible to tell from the output above whether a line like:

probs_y2_t = Vindex(probs_y)[y2_help, x2_help, x2]

in my model above, treating the tensors like values, is going to work as intended (using Vindex or raw indexing)?

I mean, is there an ordering of the variables, or some gerrymandering that is sometimes necessary, to get it to work seamlessly, as in the hmm example?

OK, looking at the shape of probs_y, I get why you have done that, since the hidden dimension is the first dimension in the hmm.py example. But I still don’t see how I would tell whether the multiple indexing is going to work out, or when I should use Vindex (from the tutorials, it is there so the code works both with and without enumeration, i.e. for sampling and inference).
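
One rough way to sanity-check a line like Vindex(probs_y)[y2_help, x2_help, x2]: Vindex (like PyTorch advanced indexing with broadcastable index tensors) first broadcasts the index tensors together, so the result's batch shape is their broadcast shape. A stdlib sketch of that shape arithmetic, using my reading of the shapes from the debug dump above (for x2 I take one of the x2_list entries, (16, 1, 1, 1, 1) — illustrative, not authoritative):

```python
def broadcast_shape(*shapes):
    # Right-aligned NumPy/PyTorch-style broadcasting of shapes.
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for sizes in zip(*padded):
        dim = max(sizes)
        assert all(d in (1, dim) for d in sizes), "shapes are not broadcastable"
        out.append(dim)
    return tuple(out)

# Index-tensor shapes as they appear in the debug output above:
y2_help = (2, 1, 1, 1, 1, 1, 30)
x2_help = (2, 16, 16, 1, 1, 1, 1)
x2 = (16, 1, 1, 1, 1)

# Vindex(probs_y)[y2_help, x2_help, x2] broadcasts the three index tensors;
# the result keeps this broadcast shape (plus any trailing dims of probs_y
# not consumed by the indexing).
print(broadcast_shape(y2_help, x2_help, x2))  # (2, 16, 16, 1, 1, 1, 30)
```

If the broadcast assertion fails, or an enumeration dimension lands in the wrong slot of the result, the indexing will not line up with the plate/enumeration layout in the trace.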

probs_y[x.squeeze(-1)] includes a squeeze so that the result does not include an extraneous 1 dimension. in particular so that the 16 dimension for x_0 lines up with the first available enumeration dimension (the first dimension left of the outermost plate)
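
To make that concrete, here is a stdlib sketch of the resulting shapes, using the model_0 sizes (16 hidden states; I'll assume 88 tones for the tones plate as in hmm.py, though the exact tone count doesn't matter for the argument):

```python
def index_first_dim(table_shape, index_shape):
    # probs_y[idx] with an integer tensor idx indexes the first dim of
    # probs_y: the result replaces that dim with idx's shape.
    return tuple(index_shape) + tuple(table_shape[1:])

probs_y = (16, 88)   # (hidden states, tones); 88 assumed from hmm.py

x_enum = (16, 1)     # enumerated x: 16 at dim -2, placeholder 1 at dim -1

with_squeeze = index_first_dim(probs_y, x_enum[:-1])  # x.squeeze(-1) -> (16,)
without_squeeze = index_first_dim(probs_y, x_enum)

print(with_squeeze)     # (16, 88): 88 lines up with the tones plate at dim -1
print(without_squeeze)  # (16, 1, 88): extra size-1 dim pushes 16 out to dim -3
```

With the squeeze, the enumeration dimension of size 16 sits exactly at the first dimension left of the tones plate, which is where the enumeration machinery expects it.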

But isn’t the 1 what gives x the correct position for enumeration? Since x enumerates at dimensions -2 and -3, why would I want it to match the sequences plate (which doesn’t appear in any of the dimension output below the example in any case)?

Apologies I am finding this somewhat confusing.

model_0 only has a single vectorized plate, namely tones_plate

You mean one parallel plate; it also has an outer sequential plate (which is why that one doesn’t influence the dimensions). So why does it need to match the parallel plate dimension?

I mean, in one incarnation (x_0) it already lines up with the dimension to the left of the tones plate (via the size-1 dimension), and in the other (x_1, for instance) it is one further to the left still.

compare model0 and model1 and this comment:

Notice that we’re now using dim=-2 as a batch dimension (of size 10), and that the enumeration dimensions are now dims -3 and -4.
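The dimension arithmetic in that comment can be sketched as a simple rule: enumeration dimensions are allocated immediately to the left of all vectorized plate dimensions, and pyro.markov lets consecutive time steps alternate between two of them. A stdlib illustration of the allocation, as I understand it:

```python
def markov_enum_dims(max_plate_nesting, history=1):
    # First enumeration dim sits just left of the plate dims; with
    # pyro.markov (history=1), steps alternate between history + 1 dims
    # instead of each step allocating a fresh one.
    first = -(max_plate_nesting + 1)
    return [first - i for i in range(history + 1)]

# model_0: only the tones plate is vectorized (max_plate_nesting=1)
print(markov_enum_dims(1))  # [-2, -3]

# model_1: sequences plate at dim -2 plus tones plate at dim -1
print(markov_enum_dims(2))  # [-3, -4]
```

This matches both quoted comments: moving from one vectorized plate to two shifts the alternating enumeration dims from (-2, -3) to (-3, -4).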