Thanks @fehiepsi. I’ll post my question here then.
I’m trying to get my first stochastic process model working. I adapted the code from here to suit my example problem, which is simply two Gaussians with a discrete probability of each sample coming from one or the other. Right now, I’m just trying to do something simple, like estimating the probability distribution of the output of this process.
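To make the process concrete, here is a hedged sketch that forward-simulates it in plain Python (the standard-library `random` module instead of Pyro), using the parameters from the model code in this post: `c == 1` selects N(0.785, 1.0) and `c == 0` selects N(0.0, 0.01). It assumes, as the model code does, that a single switch `c` is drawn per run and shared by all observations in that run; `simulate_process` and `COMPONENTS` are hypothetical names used only for illustration.

```python
import random

# Hypothetical lookup table of (loc, scale) per value of the switch c,
# using the parameters from the model code in this post.
COMPONENTS = {1: (0.785, 1.0), 0: (0.0, 0.01)}


def simulate_process(n_obs, rng):
    """Forward-simulate one run of the process.

    Draws a ~ Beta(2, 2), then one switch c ~ Bernoulli(a) that is shared
    by all n_obs observations (mirroring the model code, which samples c once).
    """
    a = rng.betavariate(2, 2)
    c = 1 if rng.random() < a else 0
    loc, scale = COMPONENTS[c]
    return a, c, [rng.gauss(loc, scale) for _ in range(n_obs)]


rng = random.Random(0)
a, c, obs = simulate_process(8, rng)
print(f"a={a:.3f}, c={c}, first obs={obs[0]:.4f}")
```

Because Beta(2, 2) has mean 0.5, about half of many simulated runs should come from the wide Gaussian (`c == 1`).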
Here’s my code:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import HMC, MCMC

# Actual data sample
observations = torch.tensor(
    [0.00528813, -0.00589001, -1.20608593, 0.00190794,
     0.89052784, 0.66690464, 0.57295968, 0.02605967]
)

# Define the process
def model(observations):
    a_prior = dist.Beta(2, 2)
    a = pyro.sample("a", a_prior)
    c = pyro.sample('c', dist.Bernoulli(a))
    if c.item() == 1.0:
        my_dist = dist.Normal(0.785, 1.0)
    else:
        my_dist = dist.Normal(0.0, 0.01)
    for i, observation in enumerate(observations):
        measurement = pyro.sample(f'obs_{i}', my_dist, obs=observation)

# Clear parameters
pyro.clear_param_store()

# Define the MCMC kernel function
my_kernel = HMC(model)

# Define the MCMC algorithm
my_mcmc = MCMC(my_kernel,
               num_samples=5000,
               warmup_steps=50)

# Run the algorithm, passing the observations
my_mcmc.run(observations)
```
The exception raised is:
```
<ipython-input-2-a668622a0fb9> in model(observations)
     11     a = pyro.sample("a", a_prior)
     12     c = pyro.sample('c', dist.Bernoulli(a))
---> 13     if c.item() == 1.0:
     14         my_dist = dist.Normal(0.785, 1.0)
     15     else:

ValueError: only one element tensors can be converted to Python scalars
Trace Shapes:
 Param Sites:
Sample Sites:
       a dist   |
        value   |
       c dist   |
        value 2 |
```
I inspected `c` in the debugger, and for some reason it has two elements the second time `model()` is called:
```
tensor([0., 1.])
```
What is causing this? I expected it to be a simple scalar taking the value 0 or 1.
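A likely cause, judging from the `value 2` entry in the trace shapes above: Pyro's HMC kernel does not draw discrete latent sites such as `c`; it marginalizes them out by enumerating their support, so during inference `c` holds every possible value at once (`tensor([0., 1.])`) and `c.item()` fails. The sketch below, in plain Python rather than Pyro, illustrates what that marginalization computes for this model: instead of branching on one draw of `c`, it evaluates both values of `c`, looks up each component's parameters in a table, and combines the results with log-sum-exp. The names `LOCS`, `SCALES`, `normal_log_pdf` and `log_marginal_likelihood` are hypothetical, and this is an illustration of the idea, not Pyro's implementation.

```python
import math

# Hypothetical lookup tables replacing the if/else on c.item():
# index 0 -> N(0.0, 0.01), index 1 -> N(0.785, 1.0), as in model().
LOCS = [0.0, 0.785]
SCALES = [0.01, 1.0]


def normal_log_pdf(x, loc, scale):
    """Log density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def log_marginal_likelihood(observations, a):
    """log p(observations | a) with the discrete switch c summed out.

    Every value in the support {0, 1} of Bernoulli(a) is evaluated and the
    per-value log joints are combined with a numerically stable log-sum-exp.
    """
    terms = []
    for c in (0, 1):  # enumerate the support of c
        log_prior_c = math.log(a) if c == 1 else math.log1p(-a)
        log_lik = sum(normal_log_pdf(x, LOCS[c], SCALES[c]) for x in observations)
        terms.append(log_prior_c + log_lik)
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


observations = [0.00528813, -0.00589001, -1.20608593, 0.00190794,
                0.89052784, 0.66690464, 0.57295968, 0.02605967]
print(log_marginal_likelihood(observations, 0.5))
```

Inside a Pyro model, the analogous change would presumably be to select the parameters by indexing with `c` (for example `torch.tensor([0.0, 0.785])[c.long()]`) rather than calling `c.item()`; this is an untested suggestion.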
As a further test, the conditional statement works fine when I draw samples the normal way, outside of MCMC:
```python
import pyro
import pyro.distributions as dist

# Conditional switch test
a_prior = dist.Beta(2, 2)
a = pyro.sample("a", a_prior)
for i in range(5):
    c = pyro.sample('c', dist.Bernoulli(a))
    if c.item() == 1.0:
        print(1, end=' ')
    else:
        print(0, end=' ')
# Example output: 0 0 1 0 0
```