In this example, the SVI training step was implemented as follows:
# Set up stochastic variational inference: the guide's parameters are fitted
# by minimizing the Trace_ELBO loss with plain SGD.
svi = pyro.infer.SVI(
model=conditioned_data_model,
guide=guide,
optim=pyro.optim.SGD({"lr": 0.001, "momentum": 0.8}),
loss=pyro.infer.Trace_ELBO(),
)
# Beta-prior hyperparameters passed as positional args to model and guide.
# NOTE(review): assumes `prior` is a dist.Beta — confirm upstream.
params_prior = [prior.concentration1, prior.concentration0]
# Iterate over all the data and store results
losses, alpha, beta = [], [], []
pyro.clear_param_store()  # start from fresh parameters on every run
num_steps = 3000
for t in range(num_steps):
# NOTE(review): this is the line that raises the ValueError below; the
# root cause is inside `conditioned_data_model` (not shown here): the
# site "data_dist" produces a batched log_prob of shape [100, 1] where
# a scalar [] is expected. Fix it in the model, e.g. by enclosing the
# observed sample statement in `with pyro.plate(...):` or by calling
# `.to_event(...)` on the distribution — not in this loop.
losses.append(svi.step(params_prior))
# Track the variational parameters registered by `guide` — presumably
# named "alpha" and "beta" in its param store; verify against the guide.
alpha.append(pyro.param("alpha").item())
beta.append(pyro.param("beta").item())
# Final fitted variational posterior built from the last parameter values.
posterior_vi = dist.Beta(alpha[-1], beta[-1])
However, running it raises an error with the message shown below. How can it be fixed, based on the hint given in the error message?
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [9], in <cell line: 15>()
14 num_steps = 3000
15 for t in range(num_steps):
---> 16 losses.append(svi.step(params_prior))
17 alpha.append(pyro.param("alpha").item())
18 beta.append(pyro.param("beta").item())
File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles
File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
180 else:
181 for i in range(self.num_particles):
--> 182 yield self._get_trace(model, guide, args, kwargs)
File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\infer\enum.py:80, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
78 for site in model_trace.nodes.values():
79 if site["type"] == "sample":
---> 80 check_site_shape(site, max_plate_nesting)
81 for site in guide_trace.nodes.values():
82 if site["type"] == "sample":
File C:\ProgramData\Miniconda3\envs\geo\lib\site-packages\pyro\util.py:437, in check_site_shape(site, max_plate_nesting)
433 for actual_size, expected_size in zip_longest(
434 reversed(actual_shape), reversed(expected_shape), fillvalue=1
435 ):
436 if expected_size != -1 and expected_size != actual_size:
--> 437 raise ValueError(
438 "\n ".join(
439 [
440 'at site "{}", invalid log_prob shape'.format(site["name"]),
441 "Expected {}, actual {}".format(expected_shape, actual_shape),
442 "Try one of the following fixes:",
443 "- enclose the batched tensor in a with pyro.plate(...): context",
444 "- .to_event(...) the distribution being sampled",
445 "- .permute() data dimensions",
446 ]
447 )
448 )
450 # Check parallel dimensions on the left of max_plate_nesting.
451 enum_dim = site["infer"].get("_enumerate_dim")
ValueError: at site "data_dist", invalid log_prob shape
Expected [], actual [100, 1]
Try one of the following fixes:
- enclose the batched tensor in a with pyro.plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions