I am fairly new to pyro and working out a few example problems for a project. The code below is just to replicate a difference I found between the Gamma and Uniform distributions. As far as I can tell the dimensions of the two distributions are the same. However, the Gamma runs without error while the Uniform model gives the error:
RuntimeError: The size of tensor a (3) must match the size of tensor b (1000) at non-singleton dimension 0
I can’t see why these would work differently. Grateful for any help here!
import torch
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
n_cat = torch.tensor(3)
n_bins = torch.tensor(3)
n_obs = torch.tensor(1000)
mix_probs = dist.Dirichlet(torch.tensor([10]*n_cat, dtype=torch.float32)).sample()
conc_sim = 1/torch.rand((n_cat, n_bins))
rate_sim = torch.rand((n_cat, n_bins)) * 20
gamma_mix = dist.MixtureSameFamily(dist.Categorical(mix_probs),
dist.Gamma(conc_sim, rate_sim).to_event(1))
high_sim = 1/torch.rand(n_cat, n_bins)
low_sim = torch.full((n_cat, n_bins), 0, dtype=torch.float32)
unif_mix = dist.MixtureSameFamily(dist.Categorical(mix_probs),
dist.Uniform(low_sim, low_sim+high_sim).to_event(1))
samp_gamma = gamma_mix.sample((n_obs,))
samp_unif = unif_mix.sample((n_obs,))
def gamma_model(n_cat, n_bins, n_obs, sample):
mix_probs = pyro.param("mix_probs", torch.ones(n_cat, dtype=float)/n_cat, constraint=constraints.simplex)
conc = pyro.param("conc", 1/torch.rand(n_cat, n_bins), constraint=constraints.positive)
rate = pyro.param("rate", torch.rand((n_cat, n_bins)) * 20, constraint=constraints.positive)
with pyro.plate("observations", n_obs):
return pyro.sample("obs", dist.MixtureSameFamily(dist.Categorical(mix_probs), dist.Gamma(conc, rate).to_event(1)), obs=sample)
guide = pyro.infer.autoguide.AutoDelta(gamma_model)
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr":0.01})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(gamma_model, guide, adam, elbo)
for step in range(100):
svi.step(n_cat, n_bins, n_obs, samp_gamma)
print(pyro.param("mix_probs"), pyro.param("conc"), pyro.param("rate"))
def unif_model(n_cat, n_bins, n_obs, sample):
mix_probs = pyro.param("mix_probs", torch.ones(n_cat, dtype=float)/n_cat, constraint=constraints.simplex)
high = pyro.param("high", high_sim, constraint=constraints.positive)
low = pyro.param("low", low_sim, constraint=constraints.positive)
with pyro.plate("observations", n_obs):
pyro.sample("obs", dist.MixtureSameFamily(dist.Categorical(mix_probs), dist.Uniform(low, high+low).to_event(1)), obs=sample)
guide = pyro.infer.autoguide.AutoDelta(unif_model)
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr":0.01})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(unif_model, guide, adam, elbo)
for step in range(100):
svi.step(n_cat, n_bins, n_obs, samp_unif)
print(pyro.param("mix_probs"), pyro.param("high"), pyro.param("low"))