I am trying to run the following code:
from __future__ import print_function
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist
import HiddenStatesModel as HSM
import numpy as np
N = 6
n_steps = 100
pyro.set_rng_seed(101)
data = np.random.random([2,1400])
def model(data):
"""Basically p(x|z)p(z)"""
# define the hyperparameters that control the MVN prior
loc = torch.ones(2*N)
scale_tril = 10.*torch.eye(2*N)
# sample f from the beta prior
f = pyro.sample("latent_mapping",
dist.MultivariateNormal(loc, scale_tril=scale_tril))
print('f.shape = ', f.shape)
# loop over the observed data
for i in range(len(data)):
# observe datapoint i using the bernoulli likelihood
LocCovMatrix = torch.eye(2*N)
pyro.sample("obs_{}".format(i),
dist.MultivariateNormal(f, scale_tril=LocCovMatrix),
obs=data[:, i])
def guide(data):
GuideLoc = pyro.param("mu_q", torch.ones(2*N)).view(12) # Remove singleton dim
GuideCovMat = pyro.param("ssigma_q", torch.eye(2*N),
constraint=constraints.lower_cholesky)
# sample latent_fairness from the distribution Beta(alpha_q, beta_q)
pyro.sample("latent_mapping",
dist.MultivariateNormal(GuideLoc, scale_tril=GuideCovMat))
# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)
# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# do gradient steps
for step in range(n_steps):
# Passes data to model() and to guide(). Returns an estimate of -ELBO.
loss = svi.step(data)
#print(loss)
if step % 100 == 0:
print('.', end='')
# grab the learned variational parameters
mu_q = pyro.param("mu_q").item()
ssigma_q = pyro.param("ssigma_q").item()
print("\nbased on the data and our prior belief, the mapping " +
"L is %.3f +- %.3f" % (mu_q, np.sqrt(ssigma_q)))
It returns the following error:
RuntimeError: expected type torch.DoubleTensor but got torch.FloatTensor
I did not explicitly state the datatypes anywhere, and yet it produces the above error. This seems to me like it is assigning conflicting data types of its own. Where is my mistake?
Best,
Sascha