Dimension error with Gamma but not Uniform

I am fairly new to pyro and working out a few example problems for a project. The code below is just to replicate a difference I found between the Gamma and Uniform distributions. As far as I can tell the dimensions of the two distributions are the same. However, the Gamma runs without error while the Uniform model gives the error:

RuntimeError: The size of tensor a (3) must match the size of tensor b (1000) at non-singleton dimension 0

I can’t see why these would work differently. Grateful for any help here!

import torch
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

n_cat = torch.tensor(3)
n_bins = torch.tensor(3)
n_obs = torch.tensor(1000)

mix_probs = dist.Dirichlet(torch.tensor([10]*n_cat, dtype=torch.float32)).sample()
conc_sim = 1/torch.rand((n_cat, n_bins))
rate_sim = torch.rand((n_cat, n_bins)) * 20

gamma_mix = dist.MixtureSameFamily(dist.Categorical(mix_probs),
                                   dist.Gamma(conc_sim, rate_sim).to_event(1))

high_sim = 1/torch.rand(n_cat, n_bins)
low_sim = torch.full((n_cat, n_bins), 0, dtype=torch.float32)
unif_mix = dist.MixtureSameFamily(dist.Categorical(mix_probs),
                                  dist.Uniform(low_sim, low_sim+high_sim).to_event(1))

samp_gamma = gamma_mix.sample((n_obs,))
samp_unif = unif_mix.sample((n_obs,))

def gamma_model(n_cat, n_bins, n_obs, sample):
    mix_probs = pyro.param("mix_probs", torch.ones(n_cat, dtype=float)/n_cat, constraint=constraints.simplex)
    conc = pyro.param("conc", 1/torch.rand(n_cat, n_bins), constraint=constraints.positive)
    rate = pyro.param("rate", torch.rand((n_cat, n_bins)) * 20, constraint=constraints.positive)
    with pyro.plate("observations", n_obs):
        return pyro.sample("obs", dist.MixtureSameFamily(dist.Categorical(mix_probs), dist.Gamma(conc, rate).to_event(1)), obs=sample)
        
guide = pyro.infer.autoguide.AutoDelta(gamma_model)
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr":0.01})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(gamma_model, guide, adam, elbo)
for step in range(100):
    svi.step(n_cat, n_bins, n_obs, samp_gamma)
    
print(pyro.param("mix_probs"), pyro.param("conc"), pyro.param("rate"))
    
def unif_model(n_cat, n_bins, n_obs, sample):
    mix_probs = pyro.param("mix_probs", torch.ones(n_cat, dtype=float)/n_cat, constraint=constraints.simplex)
    high = pyro.param("high", high_sim, constraint=constraints.positive)
    low = pyro.param("low", low_sim, constraint=constraints.positive)
    with pyro.plate("observations", n_obs):
        pyro.sample("obs", dist.MixtureSameFamily(dist.Categorical(mix_probs), dist.Uniform(low, high+low).to_event(1)), obs=sample)
        
guide = pyro.infer.autoguide.AutoDelta(unif_model)
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr":0.01})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(unif_model, guide, adam, elbo)
for step in range(100):
    svi.step(n_cat, n_bins, n_obs, samp_unif)
    
print(pyro.param("mix_probs"), pyro.param("high"), pyro.param("low"))

Hi @emkoch . This might be related to an upstream issue mentioned in the definition of the support of the MixtureSameFamily distribution:

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

Hi @ordabayev , thanks for the response. I think this is exactly the problem. Unless I’m mistaken, there is no there is actually no implementation of the support calculation for a mixture of Uniform distributions, or maybe any distribution where the support varies among mixture components.

Aside from the batch dimension issue, torch uses the union when _validate_sample is called:

valid = support.check(value)
        if not valid.all():
            raise ValueError(
                "Expected value argument "
                f"({type(value).__name__} of shape {tuple(value.shape)}) "
                f"to be within the support ({repr(support)}) "
                f"of the distribution {repr(self)}, "
                f"but found invalid values:\n{value}"
            )

Trying to hack around this by setting validate_args=False did not work for me.

You can disable validation by setting pyro.enable_validation(False)