How to fix a Pyro shape error using the hint "enclose the batched tensor in a `with pyro.plate(...):` context"

In this example, the SVI training loop is implemented as follows:

# Stochastic variational inference: fit `guide` to `conditioned_data_model`
# using the Trace_ELBO objective and SGD (lr=0.001, momentum=0.8).
# NOTE(review): the ValueError shown below is raised during the first
# svi.step() by the *model's* sample site "data_dist" (log_prob shape
# [100, 1] where [] was expected). The fix therefore belongs inside
# conditioned_data_model (not shown here): declare those batch dims with
# `with pyro.plate(...):` or move them into the event shape with
# `.to_event(...)`.
svi = pyro.infer.SVI(
    model=conditioned_data_model,
    guide=guide,
    optim=pyro.optim.SGD({"lr": 0.001, "momentum": 0.8}),
    loss=pyro.infer.Trace_ELBO(),
)

# Positional arguments forwarded to both the model and the guide on every
# svi.step(...) call. Presumably `prior` is a Beta distribution, so these are
# its two concentration parameters (verify against where `prior` is defined).
params_prior = [prior.concentration1, prior.concentration0]

# Run num_steps SVI updates and record the loss plus the current variational
# parameters after each update (this loops over optimization steps, not data).
losses, alpha, beta = [], [], []
# Reset Pyro's global param store so parameters from an earlier run are not
# reused; must happen before the first svi.step().
pyro.clear_param_store()

num_steps = 3000
for t in range(num_steps):
    # One gradient update; svi.step returns the loss estimate for this step.
    losses.append(svi.step(params_prior))
    # Assumes the guide registers scalar params named "alpha" and "beta"
    # (TODO confirm); .item() extracts them as Python numbers.
    alpha.append(pyro.param("alpha").item())
    beta.append(pyro.param("beta").item())

# Approximate posterior built from the parameters after the final step.
posterior_vi = dist.Beta(alpha[-1], beta[-1])

But running it raises the following error. How can I fix it based on the given hint?

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [9], in <cell line: 15>()
     14 num_steps = 3000
     15 for t in range(num_steps):
---> 16     losses.append(svi.step(params_prior))
     17     alpha.append(pyro.param("alpha").item())
     18     beta.append(pyro.param("beta").item())

File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
    180 else:
    181     for i in range(self.num_particles):
--> 182         yield self._get_trace(model, guide, args, kwargs)

File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\enum.py:80, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     78 for site in model_trace.nodes.values():
     79     if site["type"] == "sample":
---> 80         check_site_shape(site, max_plate_nesting)
     81 for site in guide_trace.nodes.values():
     82     if site["type"] == "sample":

File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\util.py:437, in check_site_shape(site, max_plate_nesting)
    433 for actual_size, expected_size in zip_longest(
    434     reversed(actual_shape), reversed(expected_shape), fillvalue=1
    435 ):
    436     if expected_size != -1 and expected_size != actual_size:
--> 437         raise ValueError(
    438             "\n  ".join(
    439                 [
    440                     'at site "{}", invalid log_prob shape'.format(site["name"]),
    441                     "Expected {}, actual {}".format(expected_shape, actual_shape),
    442                     "Try one of the following fixes:",
    443                     "- enclose the batched tensor in a with pyro.plate(...): context",
    444                     "- .to_event(...) the distribution being sampled",
    445                     "- .permute() data dimensions",
    446                 ]
    447             )
    448         )
    450 # Check parallel dimensions on the left of max_plate_nesting.
    451 enum_dim = site["infer"].get("_enumerate_dim")

ValueError: at site "data_dist", invalid log_prob shape
  Expected [], actual [100, 1]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions


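I would first check the `.batch_shape` and `.event_shape` of `data_dist`. The error says the log_prob at site "data_dist" has shape [100, 1] while Pyro expected [] (no plate is active there), so those batch dimensions must either be declared with `pyro.plate(...)` or moved into the event shape with `.to_event(...)`. The "Tensor shapes in Pyro" tutorial may help in understanding the given hint.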