RuntimeError: log_vml_cpu not implemented for 'Long'

Hello,
I am working on the Dsprites Dataset and have created a Causal Variational Auto Encoder. I am trying to answer counterfactual queries like “given this image of a heart with this orientation, position, and scale, what would it have looked like if it were a square?”

While building the Structural Causal Model and conditioning on it I am getting the runtime error of log_vml_cpu not implemented for 'Long' while running the Inference. This looks like some GPU to CPU issue of Pyro or Pytorch but I am not sure. Here is the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-35-a6f1c970e088> in <module>()
     13 #posterior = MCMC(kernel, num_samples=1000, warmup_steps=50)
     14 
---> 15 posterior = pyro.infer.Importance(conditioned_model, num_samples = 1).run(vae, mu, sigma)
     16 #posterior.run(vae, mu, sigma)
     17 

3 frames
/usr/local/lib/python3.6/dist-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
    222         self._reset()
    223         with poutine.block():
--> 224             for i, vals in enumerate(self._traces(*args, **kwargs)):
    225                 if len(vals) == 2:
    226                     chain_id = 0

/usr/local/lib/python3.6/dist-packages/pyro/infer/importance.py in _traces(self, *args, **kwargs)
     48             model_trace = poutine.trace(
     49                 poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
---> 50             log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
     51             yield (model_trace, log_weight)
     52 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
    189                 else:
    190                     try:
--> 191                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
    192                     except ValueError:
    193                         _, exc_value, traceback = sys.exc_info()

/usr/local/lib/python3.6/dist-packages/pyro/distributions/delta.py in log_prob(self, x)
     58     def log_prob(self, x):
     59         v = self.v.expand(self.shape())
---> 60         log_prob = (x == v).type(x.dtype).log()
     61         log_prob = sum_rightmost(log_prob, self.event_dim)
     62         return log_prob + self.log_density

RuntimeError: log_vml_cpu not implemented for 'Long' 

Here is the code of my SCM:

from pyro.infer.importance import Importance
from pyro.infer.mcmc import MCMC
from pyro.infer.mcmc.nuts import HMC


intervened_model = pyro.do(SCM, data={"Y_shape": torch.tensor(1)})
conditioned_model = pyro.condition(intervened_model, data={
                                       "X": recon_x1, 
                                       "Y_shape": torch.tensor(0),
                                       "Z":z1})

#kernel = HMC(conditioned_model, step_size=0.8, num_steps=4)
#posterior = MCMC(kernel, num_samples=1000, warmup_steps=50)

posterior = pyro.infer.Importance(conditioned_model, num_samples = 1).run(vae, mu, sigma)
#posterior.run(vae, mu, sigma)

marginal = posterior.EmpiricalMarginal(posterior, )

print(type(posterior))
print(posterior)

result = []
for i in range(10):
  trace = posterior()
  x = trace.nodes['Nx']['value']
  y = trace.nodes['Ny']['value']
  z = trace.nodes['Nz']['value']
  con_obj = pyro.condition(intervened_model, data = {"Nx": x,"Ny": y, "Nz": z})
#   result.append(con_obj()[2])
  
# recon_x2,y2,z2 = con_obj(vae, mu, sigma)
# print(y2)
# recon_check(recon_x1.reshape(-1, 64, 64)[0], recon_x2.reshape(-1, 64, 64)[0])

Please let me know how to debug this or what the issue is. Highly appreciated

It seems to me that somewhere in your model code, you are using a Delta distribution that takes in a LongTensor instead of a FloatTensor. The fix might just be to change that. e.g. dist.Delta(torch.tensor(1)) -> dist.Delta(torch.tensor(1.0)).

We have accounted for the delta distribution by passing a float tensor. It’s throwing another error - “log_vml_cpu not implemented for ‘Byte’”.

This is how our data loader look like, hopefully this will give you some insight.
dataset_zip = np.load(
‘dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz’,
encoding = ‘bytes’,
allow_pickle=True
)

RuntimeError Traceback (most recent call last)
in ()
11
12 # posterior = pyro.infer.Importance(conditioned_model, num_samples = 1)
—> 13 posterior.run(vae, mu, sigma)
14 print(type(posterior))
15 print(posterior)

11 frames
/usr/local/lib/python3.6/dist-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
    222         self._reset()
    223         with poutine.block():
--> 224             for i, vals in enumerate(self._traces(*args, **kwargs)):
    225                 if len(vals) == 2:
    226                     chain_id = 0

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
    278 
    279     def _traces(self, *args, **kwargs):
--> 280         for sample in self.sampler._traces(*args, **kwargs):
    281             yield sample
    282 

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
    212             progress_bar = ProgressBar(self.warmup_steps, self.num_samples, disable=self.disable_progbar)
    213         self.logger = initialize_logger(self.logger, logger_id, progress_bar, log_queue)
--> 214         self.kernel.setup(self.warmup_steps, *args, **kwargs)
    215         params = self.kernel.initial_params
    216         with optional(progress_bar, not is_multiprocessing):

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    260         self._warmup_steps = warmup_steps
    261         if self.model is not None:
--> 262             self._initialize_model_properties(args, kwargs)
    263         potential_energy = self.potential_fn(self.initial_params)
    264         self._cache(self.initial_params, potential_energy, None)

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    233             jit_compile=self._jit_compile,
    234             jit_options=self._jit_options,
--> 235             skip_jit_warnings=self._ignore_jit_warnings,
    236         )
    237         self.potential_fn = potential_fn

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains)
    392     # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
    393     init_params = _get_init_params(model, model_args, model_kwargs, transforms,
--> 394                                    pe_maker.get_potential_fn(), prototype_params, num_chains=num_chains)
    395     potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
    396     return init_params, potential_fn, transforms, model_trace

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params, max_tries_initial_params, num_chains, strategy)
    306                 samples = {name: trace.nodes[name]["value"].detach() for name in params}
    307                 params = {k: transforms[k](v) for k, v in samples.items()}
--> 308             pe_grad, pe = potential_grad(potential_fn, params)
    309 
    310             if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))):

/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
     73     for node in z_nodes:
     74         node.requires_grad_(True)
---> 75     potential_energy = potential_fn(z)
     76     grads = grad(potential_energy, z_nodes)
     77     for node in z_nodes:

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn(self, params)
    255         model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
    256                                                           **self.model_kwargs)
--> 257         log_joint = self.trace_prob_evaluator.log_prob(model_trace)
    258         for name, t in self.transforms.items():
    259             log_joint = log_joint - torch.sum(

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in log_prob(self, model_trace)
    216         """
    217         if not self.has_enumerable_sites:
--> 218             return model_trace.log_prob_sum()
    219         log_probs = self._get_log_factors(model_trace)
    220         with shared_intermediates() as cache:

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
    189                 else:
    190                     try:
--> 191                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
    192                     except ValueError:
    193                         _, exc_value, traceback = sys.exc_info()

/usr/local/lib/python3.6/dist-packages/pyro/distributions/delta.py in log_prob(self, x)
     58     def log_prob(self, x):
     59         v = self.v.expand(self.shape())
---> 60         log_prob = (x == v).type(x.dtype).log()
     61         log_prob = sum_rightmost(log_prob, self.event_dim)
     62         return log_prob + self.log_density

RuntimeError: log_vml_cpu not implemented for 'Byte'

Have you converted your dataset to a torch tensor, and if so, what is the type of your data (dset.dtype)? Could you paste your full model code here? It is hard to debug by looking at the error trace because the issue is most likely in the data that you are passing in or your model.

Also, regardless of this specific issue, delta distributions will not work with HMC; so depending on what you are trying to do, you will need to look at alternatives.

This is how our code looks like along with our causal model:

mu, sigma = vae.encoder.forward(x,vae.remap_y(y))
mu = mu.cpu()
sigma = sigma.cpu()
recon_x1, y1, z1 = SCM(vae, mu, sigma)
print(recon_x1.type())
print(y1.type())
print(z1.type())

which prints:

torch.ByteTensor
torch.FloatTensor
torch.FloatTensor

Further our causal models looks like -

def SCM(vae, mu, sigma):
    z_dim = vae.z_dim
    Ny, Y, ys = [], [], []
    Nx = pyro.sample("Nx", dist.Uniform(torch.zeros(vae.image_dim), torch.ones(vae.image_dim)))
    Nz = pyro.sample("Nz", dist.Normal(torch.zeros(z_dim), torch.ones(z_dim)))
    m = torch.distributions.gumbel.Gumbel(torch.tensor(0.0), torch.tensor(1.0))
    for label_id in range(6):
        name = vae.label_names[label_id]
        length = vae.label_shape[label_id]
        new = pyro.sample("Ny_%s"%name, dist.Uniform(torch.zeros(length), torch.ones(length)) )
        Ny.append(new)
        gumbel_vars = torch.tensor([m.sample() for _ in range(length)])
        max_ind = torch.argmax(torch.log(new) + gumbel_vars).item()
        print(torch.tensor(max_ind*1.0).type())
        Y.append(pyro.sample("Y_%s"%name, dist.Delta(torch.tensor(max_ind*1.0))))
        ys.append(torch.nn.functional.one_hot(torch.tensor(max_ind), int(length)))  
    Y = torch.tensor(Y)
    ys = torch.cat(ys).to(torch.float32).reshape(1,-1).cuda()
    Z = pyro.sample("Z", dist.Delta(mu + Nz*sigma))
    zs = Z.cuda()
    p = vae.decoder.forward(zs,ys)
    X = pyro.sample("X", dist.Delta(Nx < p.cpu()))
    return X, Y, Z

I see. I think the issue is in dist.Delta(Nx < p.cpu()). You probably want it to be something like dist.Delta((Nx < p).type(torch.float)). As I mentioned earlier, this will still give you problems if you are going to run HMC.

2 Likes

Thank you for helping us out. This worked.

Yes, this is still under construction. We’ll change our inference algorithm. Thank you for the feedback.