The error produced by the first version of the code in this post is as follows:
Batch - 1
Batch - 2
Batch - 3
Batch - 4
Batch - 5
Batch - 6
Batch - 7
Batch - 8
Batch - 9
Batch - 10
Batch - 11
Batch - 12
Batch - 13
Batch - 14
Batch - 15
Batch - 16
Batch - 17
Batch - 18
Batch - 19
Batch - 20
Batch - 21
Batch - 22
Traceback (most recent call last):
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
return super().__call__(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "E:/科研/科研探索/20220901/test.py", line 47, in forward
latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
apply_stack(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
frame._process_message(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
return BroadcastMessenger._pyro_sample(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
return func(*args, **kwds)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 8 vs 2
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "E:/科研/科研探索/20220901/test.py", line 99, in <module>
batch_loss += svi.step(x_data, y_data)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\svi.py", line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 140, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\elbo.py", line 182, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 58, in _get_trace
"flat", self.max_plate_nesting, model, guide, args, kwargs
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\enum.py", line 67, in get_importance_trace
).get_trace(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
self(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
raise exc from e
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
return super().__call__(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "E:/科研/科研探索/20220901/test.py", line 47, in forward
latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
apply_stack(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
frame._process_message(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
return BroadcastMessenger._pyro_sample(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
return func(*args, **kwds)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 8 vs 2
Trace Shapes:
Param Sites:
Sample Sites:
sigma dist |
value |
linear.weight dist | 1 3
value | 1 3
linear.bias dist | 1
value | 1
latent dist |
value 8 |
Process finished with exit code 1
After I changed the plate size to use `x.shape` in the first plate, the error became the following:
D:\Users\83451\anaconda3\envs\BaseEnv\python.exe E:/科研/科研探索/20220901/test.py
Batch - 1
Traceback (most recent call last):
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
return super().__call__(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "E:/科研/科研探索/20220901/test.py", line 47, in forward
latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
apply_stack(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
frame._process_message(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
return BroadcastMessenger._pyro_sample(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
return func(*args, **kwds)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 3 vs 8
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
return super().__call__(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 759, in forward
self._setup_prototype(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 935, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 636, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 158, in _setup_prototype
*args, **kwargs
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
self(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
raise exc from e
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
return super().__call__(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "E:/科研/科研探索/20220901/test.py", line 47, in forward
latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
apply_stack(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
frame._process_message(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
return BroadcastMessenger._pyro_sample(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
return func(*args, **kwds)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 3 vs 8
Trace Shapes:
Param Sites:
Sample Sites:
sigma dist |
value |
linear.weight dist | 1 3
value | 1 3
linear.bias dist | 1
value | 1
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "E:/科研/科研探索/20220901/test.py", line 99, in <module>
batch_loss += svi.step(x_data, y_data)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\svi.py", line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 140, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\elbo.py", line 182, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 58, in _get_trace
"flat", self.max_plate_nesting, model, guide, args, kwargs
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\enum.py", line 61, in get_importance_trace
*args, **kwargs
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
self(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
raise exc from e
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
return super().__call__(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 759, in forward
self._setup_prototype(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 935, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 636, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 158, in _setup_prototype
*args, **kwargs
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
self(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
raise exc from e
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
return super().__call__(*args, **kwargs)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "E:/科研/科研探索/20220901/test.py", line 47, in forward
latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
apply_stack(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
frame._process_message(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
return BroadcastMessenger._pyro_sample(msg)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
return func(*args, **kwds)
File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 3 vs 8
Trace Shapes:
Param Sites:
Sample Sites:
sigma dist |
value |
linear.weight dist | 1 3
value | 1 3
linear.bias dist | 1
value | 1
Trace Shapes:
Param Sites:
Sample Sites:
Process finished with exit code 1