Function AddBackward0 returned an invalid gradient at index 1 - expected type torch.cuda.FloatTensor but got torch.FloatTensor

def semisupervised_loss_and_grads(model, guide, *args, **kwargs):
    """Compute the semi-supervised ELBO loss and backpropagate its surrogates.

    For each of ``NUM_PARTICLE`` particles the guide is traced and the model is
    replayed against that guide trace. Sample sites whose names start with
    ``'unsup'`` form the unsupervised ELBO (model log-prob minus guide log-prob).
    Sites starting with ``'sup'`` form the supervised term, weighted by
    ``ALPHA['alpha']``. The result is averaged over particles and scaled by
    ``len(unsupervised_dataloader)`` (N/M).

    Gradients come from two surrogate objectives, and ``backward`` is called
    on their sum here:

    * theta: model log-prob of sites whose names start with
      ``'unsup_first_name'`` plus ``alpha`` times those starting with
      ``'sup_first_name'``;
    * phi: a score-function term ``(elbo_particle - 1).detach() * log q``,
      where ``elbo_particle`` is that particle's own unsupervised ELBO.

    Relies on the module-level names ``ALPHA``, ``NUM_PARTICLE``, ``poutine``
    and ``unsupervised_dataloader``.

    Returns:
        The negated ELBO (still attached to the autograd graph).
    """
    import torch

    alpha = ALPHA['alpha']
    num_particle = NUM_PARTICLE

    elbo_u, elbo_s = 0., 0.
    # One (model_trace, guide_trace, unsupervised ELBO) tuple per particle,
    # kept so the surrogate objectives below are built from the same samples.
    particles = []

    for _ in range(num_particle):
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

        model_sup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample' and name.startswith('sup'))
        model_unsup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample' and name.startswith('unsup'))
        guide_particle = guide_trace.log_prob_sum()

        elbo_u_particle = model_unsup_particle - guide_particle
        elbo_u += elbo_u_particle
        elbo_s += model_sup_particle
        particles.append((model_trace, guide_trace, elbo_u_particle))

    # Must multiply the ELBO by N/M (i.e. len(unsupervised_dataloader)),
    # where N is the data size and M the batch size; also average over particles.
    batch_constant = len(unsupervised_dataloader)
    scale = batch_constant / num_particle

    elbo = scale * (elbo_u + alpha * elbo_s)

    surrogate_theta_particle = 0.
    surrogate_phi_particle = 0.

    for model_trace, guide_trace, elbo_u_particle in particles:
        guide_particle = guide_trace.log_prob_sum()

        # Theta gradient expectation: model log-prob of the *first_name sites.
        model_z_sup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample' and name.startswith('sup_first_name'))
        model_z_unsup_particle = model_trace.log_prob_sum(
            lambda name, site: site['type'] == 'sample' and name.startswith('unsup_first_name'))
        surrogate_theta_particle += model_z_unsup_particle + alpha * model_z_sup_particle

        # Phi gradient expectation (score function): for f = log p - log q,
        # grad_phi E_q[f] = E_q[(f - 1) * grad_phi log q], so each particle's
        # log q is weighted by its OWN detached ELBO minus one.
        surrogate_phi_particle += (elbo_u_particle - 1).detach() * guide_particle

    # Scale the gradient surrogates by N/M and average over particles.
    surrogate_theta_particle = -scale * surrogate_theta_particle
    surrogate_phi_particle = -scale * surrogate_phi_particle

    # Backprop on the theta and phi surrogates. retain_graph=True because the
    # returned -elbo shares this graph and may be differentiated afterwards.
    # NOTE(review): an error like "expected torch.cuda.FloatTensor but got
    # torch.FloatTensor" here suggests some tensor feeding this graph (e.g. a
    # distribution parameter) is on the CPU while the rest is on CUDA — confirm
    # every parameter is created on the same device.
    surrogate = surrogate_theta_particle + surrogate_phi_particle
    if torch.is_tensor(surrogate) and surrogate.requires_grad:
        surrogate.backward(retain_graph=True)

    return -elbo

I have this code, which samples from a OneHotCategorical distribution, but when I change it to a RelaxedOneHotCategoricalStraightThrough it breaks. What's going on? Why would it raise this error when it's a RelaxedOneHotCategoricalStraightThrough? The temperature tensor for RelaxedOneHotCategoricalStraightThrough is a FloatTensor moved to DEVICE, which is determined at startup depending on the user's machine; it could be the CPU or CUDA.

This happens at the line `surrogate_phi_particle.backward(retain_graph=True)`, which is not included in the snippet above.

@DolanTheMFWizard I looked at the implementation of the Relaxed… distributions but found nothing suspicious. Could you try to create a minimal reproducible example so we can track down the issue?

1 Like

Sure. I also found that the following code causes the same issue:

# Excerpt (enclosing function not shown): each step runs the LSTM on the
# current input (the previous step's sample), draws the next character with
# pyro.sample, and appends the decoded letters to `names`.
for index in range(MAX_NAME_LENGTH):

            char_dist, hidd_cell_states = lstm.forward(input_tensor, hidd_cell_states)

            # NOTE(review): the return value of torch.softmax is discarded, so this
            # line does not change char_dist or anything sampled below. Assign the
            # result (e.g. `char_dist = torch.softmax(char_dist, dim=2)`) if
            # normalized probabilities were intended.
            torch.softmax(char_dist, dim=2)

            input_tensor = pyro.sample(f"sup_{address}_{index}",
            # NOTE(review): the distribution argument and the closing parenthesis
            # of this pyro.sample call are missing from this excerpt.



            # Sampled char should be an index not a one-hot

            # argmax over dim 2 turns the sample back into indices into ALL_LETTERS
            # (the lambda's parameter `index` shadows the loop variable `index`).
            chars_at_indexes = list(

                map(lambda index: ALL_LETTERS[int(index.item())], torch.argmax(input_tensor, dim=2).squeeze(0)))

            # Add sampled characters to names

            for i, char in enumerate(chars_at_indexes):

                names[i] += char

Without the `torch.softmax` line this works, but with it, it throws the same error as above.

I move all my tensors with `.to(DEVICE)`, where `DEVICE` is:

# Run on CUDA when it is available, otherwise on the CPU ("cpu" is the valid
# torch device type for the fallback).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

So the error appears if I change the above OneHotCategorical to RelaxedOneHotCategoricalStraightThrough, or if I add the `torch.softmax` line.

I also replaced all the `dist.OneHotCategorical` instances with my own custom-made `OneHotCategorical`, which is the same as the Pyro code:

import torch

from pyro.distributions.torch_distribution import TorchDistributionMixin

class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin):
    """Pyro-compatible ``OneHotCategorical`` with an explicit ``expand``.

    ``expand`` rebuilds the distribution with its parameters broadcast to a new
    batch shape. It keeps whichever parameterization (``probs`` or ``logits``)
    the wrapped ``Categorical`` holds.
    """

    def expand(self, batch_shape):
        """Return a new ``OneHotCategorical`` with batch shape ``batch_shape``.

        Args:
            batch_shape: target batch shape (anything ``torch.Size`` accepts).

        Returns:
            OneHotCategorical: a copy whose parameters are expanded to
            ``batch_shape + self.event_shape``. Any per-instance
            ``validate_args`` setting is carried over.
        """
        batch_shape = torch.Size(batch_shape)
        shape = batch_shape + self.event_shape
        # torch.distributions stores an explicit per-instance validation flag as
        # ``_validate_args``; None lets the new instance use the class default.
        validate_args = self.__dict__.get('_validate_args')
        # torch's OneHotCategorical exposes probs/logits as properties that
        # delegate to the wrapped ``_categorical``, so the concrete parameter
        # lives in that object's __dict__, not in self.__dict__.
        if 'probs' in self._categorical.__dict__:
            probs = self.probs.expand(shape)
            return OneHotCategorical(probs=probs, validate_args=validate_args)
        logits = self.logits.expand(shape)
        return OneHotCategorical(logits=logits, validate_args=validate_args)

and I got the same error using this.