I am working on the Dsprites Dataset and have created a Causal Variational Auto Encoder. I am trying to answer counterfactual queries like “given this image of a heart with this orientation, position, and scale, what would it have looked like if it were a square?”
While building the Structural Causal Model and conditioning on it I am getting the runtime error of log_vml_cpu not implemented for 'Long'
while running the Inference. This looks like some GPU to CPU issue of Pyro or Pytorch but I am not sure. Here is the error:
RuntimeError Traceback (most recent call last)
<ipython-input-35-a6f1c970e088> in <module>()
13 #posterior = MCMC(kernel, num_samples=1000, warmup_steps=50)
---> 15 posterior = pyro.infer.Importance(conditioned_model, num_samples = 1).run(vae, mu, sigma)
16 #posterior.run(vae, mu, sigma)
3 frames
/usr/local/lib/python3.6/dist-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
222 self._reset()
223 with poutine.block():
--> 224 for i, vals in enumerate(self._traces(*args, **kwargs)):
225 if len(vals) == 2:
226 chain_id = 0
/usr/local/lib/python3.6/dist-packages/pyro/infer/importance.py in _traces(self, *args, **kwargs)
48 model_trace = poutine.trace(
49 poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
---> 50 log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
51 yield (model_trace, log_weight)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
189 else:
190 try:
--> 191 log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
192 except ValueError:
193 _, exc_value, traceback = sys.exc_info()
/usr/local/lib/python3.6/dist-packages/pyro/distributions/delta.py in log_prob(self, x)
58 def log_prob(self, x):
59 v = self.v.expand(self.shape())
---> 60 log_prob = (x == v).type(x.dtype).log()
61 log_prob = sum_rightmost(log_prob, self.event_dim)
62 return log_prob + self.log_density
RuntimeError: log_vml_cpu not implemented for 'Long'
Here is the code of my SCM:
from pyro.infer.importance import Importance
from pyro.infer.mcmc import MCMC
from pyro.infer.mcmc.nuts import HMC
intervened_model = pyro.do(SCM, data={"Y_shape": torch.tensor(1)})
conditioned_model = pyro.condition(intervened_model, data={
"X": recon_x1,
"Y_shape": torch.tensor(0),
#kernel = HMC(conditioned_model, step_size=0.8, num_steps=4)
#posterior = MCMC(kernel, num_samples=1000, warmup_steps=50)
posterior = pyro.infer.Importance(conditioned_model, num_samples = 1).run(vae, mu, sigma)
#posterior.run(vae, mu, sigma)
marginal = posterior.EmpiricalMarginal(posterior, )
result = []
for i in range(10):
trace = posterior()
x = trace.nodes['Nx']['value']
y = trace.nodes['Ny']['value']
z = trace.nodes['Nz']['value']
con_obj = pyro.condition(intervened_model, data = {"Nx": x,"Ny": y, "Nz": z})
# result.append(con_obj()[2])
# recon_x2,y2,z2 = con_obj(vae, mu, sigma)
# print(y2)
# recon_check(recon_x1.reshape(-1, 64, 64)[0], recon_x2.reshape(-1, 64, 64)[0])
Please let me know how to debug this or what the issue is. Highly appreciated