I'm getting a ValueError: at site "b", invalid log_prob shape. Could anyone please help? I am trying to run the code from this repo: https://github.com/nyu-mll/nlu-test-sets

(irt_project) charchit@ainode01:~/test_irt_nlu_2/nlu-test-sets$ bash irt_scripts/estimate_irt_params.sh
Alpha Std 0.3, Diff Guess Std 1.0
Missing:
 set()
Use all examples:
abductive-nli num test items:  766
adversarial-nli num test items:  3200
arc-challenge num test items:  1172
arc-easy num test items:  2376
arct num test items:  444
boolq num test items:  1635
cb num test items:  28
commonsenseqa num test items:  611
copa num test items:  50
cosmosqa num test items:  1493
hellaswag num test items:  5021
mcscript num test items:  3610
mctaco num test items:  1332
mnli num test items:  9824
mrqa-nq num test items:  6418
mutual-plus num test items:  443
mutual num test items:  443
newsqa num test items:  4293
piqa num test items:  919
qamr num test items:  18770
quail num test items:  556
quoref num test items:  1209
rte num test items:  139
snli num test items:  9824
socialiqa num test items:  977
squad_v2 num test items:  6198
wic num test items:  319
winogrande num test items:  634
wsc num test items:  52
Extracted from 29 files
Collected response patterns for
        abductive-nli
        adversarial-nli
        arc-challenge
        arc-easy
        arct
        boolq
        cb
        commonsenseqa
        copa
        cosmosqa
        hellaswag
        mcscript
        mctaco
        mnli
        mrqa-nq
        mutual-plus
        mutual
        newsqa
        piqa
        qamr
        quail
        quoref
        rte
        snli
        socialiqa
        squad_v2
        wic
        winogrande
        wsc
Total number of items is 82756
Total combined items is 82756
Item Param Std is 1.00
Std Overwrite 0.30
Std Overwrite 1.00
elbo loss:   0%|          | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/charchit/test_irt_nlu_2/nlu-test-sets/irt_scripts/variational_irt.py", line 604, in <module>
    main(args)
  File "/home/charchit/test_irt_nlu_2/nlu-test-sets/irt_scripts/variational_irt.py", line 474, in main
    _ = train(
  File "/home/charchit/test_irt_nlu_2/nlu-test-sets/irt_scripts/variational_irt.py", line 318, in train
    elbo_loss = svi_kernel.step(weights, data_)
  File "/home/charchit/test_irt_nlu_2/nlu-test-sets/irt_scripts/weighted_ELBO.py", line 173, in step
    loss = self.loss_and_grads(self.model, self.guide, weights, *args, **kwargs)
  File "/home/charchit/test_irt_nlu_2/nlu-test-sets/irt_scripts/weighted_ELBO.py", line 115, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/home/charchit/miniconda3/envs/irt_project/lib/python3.9/site-packages/pyro/infer/elbo.py", line 182, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "/home/charchit/test_irt_nlu_2/nlu-test-sets/irt_scripts/weighted_ELBO.py", line 30, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "/home/charchit/miniconda3/envs/irt_project/lib/python3.9/site-packages/pyro/infer/enum.py", line 80, in get_importance_trace
    check_site_shape(site, max_plate_nesting)
  File "/home/charchit/miniconda3/envs/irt_project/lib/python3.9/site-packages/pyro/util.py", line 437, in check_site_shape
    raise ValueError(
ValueError: at site "b", invalid log_prob shape
  Expected [], actual [82756]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions