jprah
April 9, 2020, 6:58pm
I followed the procedure in the GMM example and wanted to classify new data with the learned GMM model, using enumeration in the full_guide() described in the tutorial.
I got an error, which I have raised as a GitHub issue, though I'm not certain whether it is truly a bug or whether I have missed something.
opened 06:43PM - 09 Apr 20 UTC
question
discussion
### Issue Description
While following the tutorial on [Gaussian Mixture Models](http://pyro.ai/examples/gmm.html), I found that the following error is raised when full_guide() is called on the new_data created earlier.
ValueError: Shape mismatch inside plate('fg_data') at site assignment_new dim -1, 180 vs 5
```
/opt/conda/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
<ipython-input-197-9648525bdb5f> in full_guide(fg_data)
22 assignment_probs = pyro.param('assignment_probs', torch.ones(len(fg_data), K) / K,
23 constraint=constraints.unit_interval)
---> 24 pyro.sample('assignment_new', dist.Categorical(assignment_probs))
25
26 """
/opt/conda/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
111 msg["is_observed"] = True
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
115
/opt/conda/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
191 pointer = pointer + 1
192
--> 193 frame._process_message(msg)
194
195 if msg["stop"]:
/opt/conda/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
13 def _process_message(self, msg):
14 super()._process_message(msg)
---> 15 return BroadcastMessenger._pyro_sample(msg)
16
17 def __enter__(self):
/opt/conda/lib/python3.7/contextlib.py in inner(*args, **kwds)
72 def inner(*args, **kwds):
73 with self._recreate_cm():
---> 74 return func(*args, **kwds)
75 return inner
76
/opt/conda/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
57 if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size:
58 raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 59 f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
60 target_batch_shape[f.dim] = f.size
61 # Starting from the right, if expected size is None at an index,
```
ValueError: Shape mismatch inside plate('fg_data') at site assignment_new dim -1, 180 vs 5
The `assignment_new` site should not have size 5 here; that size was possibly carried over from the training dataset, which had length 5.
### Environment
- OS : Ubuntu 16.04
- python version: 3.7.7 [GCC 7.3.0]
- PyTorch version: 1.4.0
- Pyro version: 1.3.1
### Code Snippet
```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints
from pyro.infer import config_enumerate

# K and global_guide are defined earlier in the GMM tutorial.
new_data = torch.arange(5.5, 6.0, 0.005)

@config_enumerate
def full_guide(fg_data):
    # Global variables.
    with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
        global_guide(fg_data)
    # Local variables.
    with pyro.plate('fg_data', len(fg_data)):
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(fg_data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment_new', dist.Categorical(assignment_probs))

full_guide(new_data)
```
I would really appreciate guidance on how to use enumeration in the guide with a new dataset.
Thanks.