Simple GMM in Pyro

Hi everyone, I’ve followed the “SVI Part I” tutorial with no trouble.
I tried to take a step further and add a bit of complexity.
This is the setup of the problem:
I have one Bernoulli distribution with parameter “phi” from which I sample a latent variable “z”.
Then I have 2 Gaussians parameterized respectively by (mu-0, std-0) and (mu-1, std-1).
The generative process samples z from the Bernoulli and then, if (z == 0), samples x from Gaussian-0, else it samples x from Gaussian-1.
I want to infer phi, (mu-0, std-0) and (mu-1, std-1) from the observed x.
This is just a simple GMM that I’ve easily solved with the EM algorithm, but now I’m a bit confused about how to implement it in Pyro and solve it with variational inference.
My current implementation doesn’t seem to work.

My code:

import pyro
import pyro.optim as optim
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
import torch
import torch.distributions.constraints as constraints


def model(data):
    # priors
    phi = pyro.sample('phi', dist.Beta(torch.tensor(10.0), torch.tensor(10.0)))
    mu_0 = pyro.sample('mu-0', dist.Normal(torch.tensor(2.0), torch.tensor(1.0)))
    mu_1 = pyro.sample('mu-1', dist.Normal(torch.tensor(5.0), torch.tensor(1.0)))
    std_0 = pyro.sample('std-0', dist.HalfCauchy(scale=torch.tensor(1.)))
    std_1 = pyro.sample('std-1', dist.HalfCauchy(scale=torch.tensor(1.)))

    for i in range(len(data)):
        z = pyro.sample(f'z-{i}', dist.Bernoulli(phi))
        if z.long().item() == 0:
            pyro.sample(f'obs-{i}', dist.Normal(mu_0, std_0), obs=data[i])
        else:
            pyro.sample(f'obs-{i}', dist.Normal(mu_1, std_1), obs=data[i])

def guide(data):
    phi_alpha = pyro.param('phi-alpha', torch.tensor(10.0))
    phi_beta = pyro.param('phi-beta', torch.tensor(10.0))

    mu_0_mu = pyro.param('mu-0-mu', torch.tensor(2.0))
    mu_0_std = pyro.param('mu-0-std', torch.tensor(1.0), constraint=constraints.positive)

    mu_1_mu = pyro.param('mu-1-mu', torch.tensor(5.0))
    mu_1_std = pyro.param('mu-1-std', torch.tensor(1.0), constraint=constraints.positive)

    std_0_std = pyro.param('std-0-std', torch.tensor(1.0), constraint=constraints.positive)
    std_1_std = pyro.param('std-1-std', torch.tensor(1.0), constraint=constraints.positive)

    pyro.sample('phi', dist.Beta(phi_alpha, phi_beta))
    pyro.sample('mu-0', dist.Normal(mu_0_mu, mu_0_std))
    pyro.sample('mu-1', dist.Normal(mu_1_mu, mu_1_std))
    pyro.sample('std-0', dist.HalfCauchy(scale=std_0_std))
    pyro.sample('std-1', dist.HalfCauchy(scale=std_1_std))


# Data Generating Process #
z_dist = torch.distributions.Bernoulli(torch.tensor(0.75))
x_dists = [torch.distributions.Normal(torch.tensor(2.), torch.tensor(1.)),
           torch.distributions.Normal(torch.tensor(5.), torch.tensor(1.8)),]
z_sample = z_dist.sample((100,))
x_sample = [x_dists[int(z)].sample() for z in z_sample]


pyro.clear_param_store()
pyro.enable_validation(True)

svi = SVI(model, guide,
          optim=optim.ClippedAdam({'lr': 0.01}),
          loss=Trace_ELBO())

c = 0
for step in range(1000):
    c += 1
    loss = svi.step(x_sample)
    if step % 100 == 0:

        phi_alpha = pyro.param('phi-alpha').item()
        phi_beta = pyro.param('phi-beta').item()
        mu_0_mu = pyro.param('mu-0-mu').item()
        mu_1_mu = pyro.param('mu-1-mu').item()

        phi = phi_alpha / (phi_alpha + phi_beta)

        print("[iteration {:>4}] loss: {:.4f} | phi: {:.2f}, mu-0: {:.2f}, mu-1: {:.2f}".format(c, loss, phi, mu_0_mu, mu_1_mu))

Also I get this warning:

/home/fabio/miniconda3/envs/Variational-Inference_pytorch/lib/python3.9/site-packages/pyro/util.py:244: UserWarning: Found vars in model but not guide: {'z-96', 'z-43', 'z-33', 'z-60', 'z-42', 'z-7', 'z-91', 'z-51', 'z-47', 'z-49', 'z-24', 'z-88', 'z-10', 'z-2', 'z-67', 'z-77', 'z-83', 'z-23', 'z-62', 'z-97', 'z-35', 'z-20', 'z-4', 'z-8', 'z-27', 'z-30', 'z-34', 'z-85', 'z-82', 'z-39', 'z-45', 'z-32', 'z-66', 'z-74', 'z-78', 'z-71', 'z-89', 'z-29', 'z-11', 'z-76', 'z-18', 'z-81', 'z-5', 'z-99', 'z-61', 'z-57', 'z-95', 'z-68', 'z-50', 'z-25', 'z-37', 'z-93', 'z-16', 'z-48', 'z-9', 'z-53', 'z-65', 'z-40', 'z-44', 'z-13', 'z-14', 'z-12', 'z-72', 'z-75', 'z-64', 'z-31', 'z-87', 'z-19', 'z-98', 'z-94', 'z-46', 'z-63', 'z-90', 'z-6', 'z-52', 'z-3', 'z-22', 'z-17', 'z-59', 'z-21', 'z-84', 'z-92', 'z-28', 'z-79', 'z-26', 'z-36', 'z-70', 'z-38', 'z-55', 'z-69', 'z-1', 'z-80', 'z-58', 'z-73', 'z-0', 'z-56', 'z-86', 'z-54', 'z-41', 'z-15'}
  warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars))
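If I understand the warning correctly, my guide has no sample statements for the per-datapoint 'z-i' sites. A minimal sketch of what the guide would additionally need (assuming a hypothetical per-datapoint variational parameter 'assignment-probs') would be something like:

# hypothetical per-datapoint variational parameters for the latent assignments
assignment_probs = pyro.param('assignment-probs', 0.5 * torch.ones(len(data)),
                              constraint=constraints.unit_interval)
for i in range(len(data)):
    pyro.sample(f'z-{i}', dist.Bernoulli(assignment_probs[i]))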

Hi @Cesch, have you taken a look at the Gaussian Mixture Model tutorial?

Hi @fritzo, thank you for answering.
Yes, I looked at the GMM tutorial, but I would like to hand-code the guide function rather than use AutoDelta.

In the meantime I tried to improve my implementation by adding pyro.plate, but now I get an error that I don’t understand:

def model(data):
    # priors
    phi = torch.tensor(0.75)
    mu = pyro.sample('mu', dist.Normal(torch.tensor([2., 5.]), torch.tensor([1., 1.])))
    std = pyro.sample('std', dist.Normal(torch.tensor([1., 1.8]), torch.tensor([0.3, 0.3])))

    with pyro.plate('obs', len(data)):
        z = pyro.sample('z', dist.Bernoulli(phi))
        pyro.sample('x', dist.Normal(mu[z.long()], std[z.long()]), obs=data)


def guide(data):
    phi = pyro.param('phi', torch.tensor(0.75), constraint=constraints.unit_interval)
    mu_mu = pyro.param('mu-mu', torch.tensor([2., 5.]))
    mu_std = pyro.param('mu-std', torch.tensor([1., 1.]), constraint=constraints.positive)
    std_mu = pyro.param('std-mu', torch.tensor([1., 1.]), constraint=constraints.positive)
    std_std = pyro.param('std-std', torch.tensor([0.3, 0.3]), constraint=constraints.positive)

    pyro.sample('mu', dist.Normal(mu_mu, mu_std))
    pyro.sample('std', dist.Normal(std_mu, std_std))

    with pyro.plate('z-sample', len(data)):
        pyro.sample('z', dist.Bernoulli(phi))

The error is about the shape at the mu site:

ValueError: at site "mu", invalid log_prob shape
  Expected [], actual [2]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Here are a couple fixes:

  1. As described in the tensor shapes tutorial, you’ll need to call .to_event(1) to use scalar distributions like Normal as joint diagonal distributions over multiple variables:
- pyro.sample('mu', dist.Normal(mu_mu, mu_std))
+ pyro.sample('mu', dist.Normal(mu_mu, mu_std).to_event(1))
  2. You’ll want a positively supported distribution for the std variable, e.g. LogNormal:
- pyro.sample('std', dist.Normal(std_mu, std_std))
+ pyro.sample('std', dist.LogNormal(std_mu, std_std).to_event(1))
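For reference, .to_event(1) moves the rightmost batch dimension into the event shape, so the two components are scored as a single joint (diagonal) distribution. A quick sketch to check the shapes:

import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(2), torch.ones(2))
print(d.batch_shape, d.event_shape)      # torch.Size([2]) torch.Size([])
print(d.log_prob(torch.zeros(2)).shape)  # torch.Size([2]) -- two independent log-probs

d2 = d.to_event(1)
print(d2.batch_shape, d2.event_shape)     # torch.Size([]) torch.Size([2])
print(d2.log_prob(torch.zeros(2)).shape)  # torch.Size([]) -- one joint log-prob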

Thank you @fritzo for your help, the code now runs smoothly but it doesn’t converge.
I thought this was a fairly simple problem, so either I have completely misunderstood how to run variational inference or this is just the wrong way to solve this particular problem (?)

This is the final working (not really) code:

import pyro
import pyro.optim as optim
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
import torch
import torch.distributions.constraints as constraints


locs_range = (0., 10.)


def model(data):
    # priors
    phi = torch.rand(1)[0]
    locs_mu = torch.rand(2) * (locs_range[1] - locs_range[0]) + locs_range[0]
    locs = pyro.sample('locs', dist.Normal(locs_mu, torch.tensor([1., 1.])).to_event(1))
    scales = pyro.sample('scales', dist.LogNormal(torch.tensor([0., 0.]), torch.tensor([0.25, 0.25])).to_event(1))

    with pyro.plate('z-sample', len(data)):
        z = pyro.sample('z', dist.Bernoulli(phi))

    with pyro.plate('obs', len(data)):
        pyro.sample('x', dist.Normal(locs[z.long()], scales[z.long()]), obs=data)


def guide(data):
    phi = pyro.param('phi', torch.rand(1)[0], constraint=constraints.unit_interval)
    locs_mu = pyro.param('locs-mu', torch.rand(2) * (locs_range[1] - locs_range[0]) + locs_range[0])
    locs_std = pyro.param('locs-std', torch.tensor([1., 1.]), constraint=constraints.positive)
    scales_mu = pyro.param('scales-mu', torch.tensor([0.0, 0.0]))
    scales_std = pyro.param('scales-std', torch.tensor([0.25, 0.25]), constraint=constraints.positive)

    pyro.sample('locs', dist.Normal(locs_mu, locs_std).to_event(1))
    pyro.sample('scales', dist.LogNormal(scales_mu, scales_std).to_event(1))

    with pyro.plate('z-sample', len(data)):
        pyro.sample('z', dist.Bernoulli(phi))


# Data Generating Process #

# parameters (to find)
g_phi = torch.tensor(0.75)
g_locs = torch.tensor([2., 5.])
g_scales = torch.tensor([1., 1.8])
#
k = g_locs.size(0)
z_dist = torch.distributions.Bernoulli(torch.tensor(0.75))
x_dists = [torch.distributions.Normal(g_locs[i], g_scales[i]) for i in range(k)]
z_sample = z_dist.sample((256,))
data = torch.cat([x_dists[int(z)].sample().unsqueeze(0) for z in z_sample])


# Variational Inference

pyro.clear_param_store()
pyro.enable_validation(True)

svi = SVI(model, guide,
          optim=optim.Adam({'lr': 0.01}),
          loss=Trace_ELBO())

c = 0
for step in range(10000):
    c += 1
    loss = svi.step(data)
    if step % 500 == 0:

        phi = pyro.param('phi')
        locs = pyro.param('locs-mu')
        scales_mu = pyro.param('scales-mu')
        scales_std = pyro.param('scales-std')
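        # mean of LogNormal(mu, sigma) is exp(mu + sigma**2 / 2)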
        with torch.no_grad():
            scales = torch.exp(scales_mu + (0.5 * scales_std**2))

        print("[iteration {:>4}] loss: {:.4f} | phi: {:.2f}, locs: {}, scales: {}".format(c, loss, phi, locs, scales))

And this is the output:

[iteration    1] loss: 2579.5171 | phi: 0.79, locs: tensor([6.8785, 8.2995], requires_grad=True), scales: tensor([1.0428, 1.0415])
[iteration  501] loss: 695.8669 | phi: 0.76, locs: tensor([5.0579, 6.4874], requires_grad=True), scales: tensor([2.3059, 3.1808])
[iteration 1001] loss: 676.5963 | phi: 0.75, locs: tensor([4.3049, 5.3582], requires_grad=True), scales: tensor([2.1125, 2.7267])
[iteration 1501] loss: 1040.1174 | phi: 0.75, locs: tensor([4.2726, 4.4476], requires_grad=True), scales: tensor([2.0634, 2.1994])
[iteration 2001] loss: 926.0440 | phi: 0.78, locs: tensor([4.2659, 4.2802], requires_grad=True), scales: tensor([2.0191, 2.1668])
[iteration 2501] loss: 774.7163 | phi: 0.69, locs: tensor([4.2719, 4.2416], requires_grad=True), scales: tensor([2.0512, 2.1277])
[iteration 3001] loss: 588.2146 | phi: 0.55, locs: tensor([4.2787, 4.2594], requires_grad=True), scales: tensor([2.0919, 2.1071])
[iteration 3501] loss: 708.5364 | phi: 0.56, locs: tensor([4.3014, 4.1924], requires_grad=True), scales: tensor([2.0750, 2.0640])
[iteration 4001] loss: 579.1063 | phi: 0.57, locs: tensor([4.2625, 4.2291], requires_grad=True), scales: tensor([2.0844, 2.0906])
[iteration 4501] loss: 621.2787 | phi: 0.46, locs: tensor([4.2265, 4.2475], requires_grad=True), scales: tensor([2.0817, 2.0447])
[iteration 5001] loss: 588.2329 | phi: 0.39, locs: tensor([4.2044, 4.2652], requires_grad=True), scales: tensor([2.0716, 2.0644])
[iteration 5501] loss: 787.8038 | phi: 0.42, locs: tensor([4.1833, 4.2317], requires_grad=True), scales: tensor([2.0564, 2.0568])
[iteration 6001] loss: 637.1835 | phi: 0.42, locs: tensor([4.2204, 4.2431], requires_grad=True), scales: tensor([2.0783, 2.0910])
[iteration 6501] loss: 936.9914 | phi: 0.43, locs: tensor([4.2396, 4.2416], requires_grad=True), scales: tensor([2.0582, 2.0589])
[iteration 7001] loss: 659.4283 | phi: 0.32, locs: tensor([4.2295, 4.2193], requires_grad=True), scales: tensor([2.0204, 2.0766])
[iteration 7501] loss: 589.2785 | phi: 0.32, locs: tensor([4.2671, 4.2057], requires_grad=True), scales: tensor([2.1348, 2.0374])
[iteration 8001] loss: 585.1319 | phi: 0.44, locs: tensor([4.1966, 4.2963], requires_grad=True), scales: tensor([2.1289, 2.0877])
[iteration 8501] loss: 642.1887 | phi: 0.64, locs: tensor([4.2335, 4.2294], requires_grad=True), scales: tensor([1.9951, 2.1458])
[iteration 9001] loss: 846.7444 | phi: 0.61, locs: tensor([4.1834, 4.1928], requires_grad=True), scales: tensor([2.0519, 2.0754])
[iteration 9501] loss: 669.9127 | phi: 0.61, locs: tensor([4.2573, 4.2074], requires_grad=True), scales: tensor([1.9930, 2.0970])

I believe you’ll get better results if you put the x/z sample statements in a single plate, e.g.:
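Something like this sketch of your model, with the 'z-sample' and 'obs' plates merged into one (the z site in the guide should then sit inside a plate with the same name):

def model(data):
    phi = torch.rand(1)[0]
    locs_mu = torch.rand(2) * (locs_range[1] - locs_range[0]) + locs_range[0]
    locs = pyro.sample('locs', dist.Normal(locs_mu, torch.tensor([1., 1.])).to_event(1))
    scales = pyro.sample('scales', dist.LogNormal(torch.tensor([0., 0.]),
                                                  torch.tensor([0.25, 0.25])).to_event(1))

    # one plate containing both the assignment z and the observation x
    with pyro.plate('data', len(data)):
        z = pyro.sample('z', dist.Bernoulli(phi))
        pyro.sample('x', dist.Normal(locs[z.long()], scales[z.long()]), obs=data)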


Hi @martinjankowiak, I tried to include the x/z sample statements in a single plate, but I didn’t get any improvement.
I’ve also tried @config_enumerate with the TraceEnum_ELBO(max_plate_nesting=1) loss as suggested in the Enumeration tutorial, but it seems to converge to a single mode, like I saw in an old issue: https://github.com/pyro-ppl/pyro/issues/635

from pyro.infer import TraceEnum_ELBO, config_enumerate

locs_range = (0., 10.)

@config_enumerate
def model(data):
    phi = torch.tensor([0.75])
    locs = pyro.param('locs', torch.randn(2) * (locs_range[1] - locs_range[0]) + locs_range[0])
    scales = pyro.param('scales', torch.tensor([1., 1.]), constraint=constraints.positive)

    with pyro.plate('obs', len(data)):
        z = pyro.sample('z', dist.Bernoulli(phi)).to(torch.int64)
        pyro.sample('x', dist.Normal(locs[z], scales[z]), obs=data)

@config_enumerate
def guide(data):
    phi = pyro.param('phi', torch.tensor([0.75]), constraint=constraints.unit_interval)
    with pyro.plate('obs', len(data)):
        pyro.sample('z', dist.Bernoulli(phi))

This is the output:

[iteration    1] loss: 13420.3818 | phi: tensor([0.7308], grad_fn=<ClampBackward1>), locs: tensor([ 0.0163, -6.6823], requires_grad=True), scales: tensor([1.1052, 1.1052], grad_fn=<AddBackward0>)
[iteration  501] loss: 744.1896 | phi: tensor([0.4958], grad_fn=<ClampBackward1>), locs: tensor([ 4.3569, -1.7495], requires_grad=True), scales: tensor([2.2398, 7.1233], grad_fn=<AddBackward0>)
[iteration 1001] loss: 569.6798 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)
[iteration 1501] loss: 569.6797 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)
[iteration 2001] loss: 569.6797 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)
[iteration 2501] loss: 569.6797 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)
[iteration 3001] loss: 569.6797 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)
[iteration 3501] loss: 569.6797 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)
[iteration 4001] loss: 569.6797 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)
[iteration 4501] loss: 569.6797 | phi: tensor([0.7500], grad_fn=<ClampBackward1>), locs: tensor([4.3569, 4.3569], requires_grad=True), scales: tensor([2.2398, 2.2398], grad_fn=<AddBackward0>)

I’ve tried declaring all the parameters in the model and omitting the guide function, and surprisingly I got better convergence:

locs_range = (0., 10.)

@config_enumerate
def model(data):
    phi = pyro.param('phi', torch.tensor([0.5]), constraint=constraints.unit_interval)
    locs = pyro.param('locs', torch.rand(2) * (locs_range[1] - locs_range[0]) + locs_range[0])
    scales = pyro.param('scales', torch.tensor([1., 1.]), constraint=constraints.positive)

    with pyro.plate('obs', len(data)):
        z = pyro.sample('z', dist.Bernoulli(phi)).to(torch.int64)
        pyro.sample('x', dist.Normal(locs[z], scales[z]), obs=data)

def guide(data):
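    # empty guide: z is enumerated out in the model and every parameter is registered
    # with pyro.param, so SVI here just maximizes the marginal likelihood (point estimates)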
    pass

# ... #

svi = SVI(model, guide,
          optim=optim.Adam({'lr': 0.1}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))

The output:

[iteration    1] loss: 1188.7950 | phi: tensor([0.5250], grad_fn=<ClampBackward1>), locs: tensor([0.7947, 2.7892], requires_grad=True), scales: tensor([1.1052, 1.1052], grad_fn=<AddBackward0>)
[iteration  501] loss: 544.1373 | phi: tensor([0.7115], grad_fn=<ClampBackward1>), locs: tensor([1.7717, 5.1969], requires_grad=True), scales: tensor([1.0235, 1.5323], grad_fn=<AddBackward0>)
[iteration 1001] loss: 544.1373 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0244, 1.5313], grad_fn=<AddBackward0>)
[iteration 1501] loss: 544.1373 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0244, 1.5313], grad_fn=<AddBackward0>)
[iteration 2001] loss: 544.1373 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0244, 1.5313], grad_fn=<AddBackward0>)
[iteration 2501] loss: 544.1373 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0244, 1.5313], grad_fn=<AddBackward0>)
[iteration 3001] loss: 544.1373 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0244, 1.5313], grad_fn=<AddBackward0>)
[iteration 3501] loss: 544.1373 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0244, 1.5313], grad_fn=<AddBackward0>)
[iteration 4001] loss: 544.1373 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0244, 1.5313], grad_fn=<AddBackward0>)
[iteration 4501] loss: 544.1374 | phi: tensor([0.7110], grad_fn=<ClampBackward1>), locs: tensor([1.7735, 5.1986], requires_grad=True), scales: tensor([1.0243, 1.5313], grad_fn=<AddBackward0>)

This gives me a point estimate of the parameters, but I would like to model them as distributions.
The problem is that if I try to include the parameters in the guide function, I get an error about a missing dimension in torch.tensordot. I get the same error even when running the Gaussian Mixture Model and toy_mixture_model_discrete_enumeration tutorials, and I recently found an issue about torch.tensordot which seems related to my case.

locs_range = (0., 10.)

@config_enumerate
def model(data):
    # priors
    weights = pyro.sample('weights', dist.Dirichlet(torch.tensor([5., 5.])))
    locs_mu = torch.rand(2) * (locs_range[1] - locs_range[0]) + locs_range[0]
    locs = pyro.sample('locs', dist.Normal(locs_mu, torch.tensor([1., 1.])).to_event(1))
    scales = pyro.sample('scales', dist.LogNormal(torch.tensor([0., 0.]),
                                                  torch.tensor([0.25, 0.25])).to_event(1))

    with pyro.plate('obs', len(data)):
        z = pyro.sample('z', dist.Categorical(weights), infer={"enumerate": "sequential"})
        pyro.sample('x', dist.Normal(locs[z], scales[z]), obs=data)


def guide(data):
    # var params
    weights_probs = pyro.param('weights_probs', torch.tensor([5., 5.]))

    locs_mu = pyro.param('locs-mu', torch.rand(2) * (locs_range[1] - locs_range[0]) + locs_range[0])
    locs_std = pyro.param('locs-std', torch.tensor([1., 1.]), constraint=constraints.positive)

    scales_mu = pyro.param('scales-mu', torch.tensor([0.0, 0.0]))
    scales_std = pyro.param('scales-std', torch.tensor([0.25, 0.25]), constraint=constraints.positive)

    pyro.sample('weights', dist.Dirichlet(weights_probs))
    pyro.sample('locs', dist.Normal(locs_mu, locs_std).to_event(1))
    pyro.sample('scales', dist.LogNormal(scales_mu, scales_std).to_event(1))

The error:

Traceback (most recent call last):
  File "/home/fabio/Python-projects/pythonProject/Variational-Inference_pytorch/GMM.py", line 68, in <module>
    loss = svi.step(data)
  File "/home/fabio/miniconda3/envs/Variational-Inference_pytorch/lib/python3.9/site-packages/pyro/infer/svi.py", line 128, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/fabio/miniconda3/envs/Variational-Inference_pytorch/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py", line 405, in loss_and_grads
    elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
  File "/home/fabio/miniconda3/envs/Variational-Inference_pytorch/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py", line 185, in _compute_dice_elbo
    return Dice(guide_trace, ordering).compute_expectation(costs)
  File "/home/fabio/miniconda3/envs/Variational-Inference_pytorch/lib/python3.9/site-packages/pyro/infer/util.py", line 303, in compute_expectation
    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
  File "/home/fabio/miniconda3/envs/Variational-Inference_pytorch/lib/python3.9/site-packages/torch/functional.py", line 929, in tensordot
    raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
RuntimeError: unsupported input to tensordot, got dims=0

What version of PyTorch are you using? You may just have to go down a version to avoid the error.

This is the configuration that gives me the tensordot error:

  • python - 3.9
  • pytorch - 1.9.0
  • pyro - 1.6.0

I downgraded to pytorch 1.8.1 and now everything works fine.

Sorry, we’re working on a fix to support PyTorch 1.9, ETA next Wednesday (for a Pyro 1.7 release). Until then you can use the older PyTorch 1.8, or you can use Pyro’s pytorch-1.9 branch:

pip install git+https://github.com/pyro-ppl/pyro.git@pytorch-1.9