I’m fairly sure this is a bug in some validation code. The following mixture of Gaussians evaluates just fine:
import torch
import pyro
import pyro.distributions as dist
# Working example: a mixture of K Gaussians over N points in P dimensions.
# N, P and K are reused by the Binomial example further down.
N = 10 # data points
P = 5 # dimensions
K = 2 # components
# simulate data
y = torch.rand([N,P]) # N points in P dimensions
print("y shape", y.shape)
# Mixing distribution: one Categorical over the K components, expanded to a batch of N.
mix = dist.Categorical(torch.rand(K)).expand([N])
# Alternative: separate (random) mixing weights for each data point.
#mix = dist.Categorical(torch.rand([N,K]))
print("mix shapes", mix.batch_shape, mix.event_shape) # [N], []
# Component distributions: Normal with random means and scale 1.
comp_batch = dist.Normal(torch.randn(N, K, P),1.)
print("comp_batch shapes:",comp_batch.batch_shape, comp_batch.event_shape) # [N x K x P], []
# Reinterpret the rightmost batch dimension (P) as an event dimension.
comp = comp_batch.to_event(1)
#comp = dist.Independent(comp_batch, reinterpreted_batch_ndims = 1) # equivalent
print("comp shapes:", comp.batch_shape, comp.event_shape) # [N x K], [P]
# Mix over the K dimension of the components.
mixture_batch = dist.MixtureSameFamily(mix, comp)
print("mixture_batch shapes:", mixture_batch.batch_shape, mixture_batch.event_shape) # [N], [P]
# Reinterpret the remaining batch dimension (N) as an event dimension as well.
mixture = mixture_batch.to_event(1)
print("mixture shapes:", mixture.batch_shape, mixture.event_shape) # [], [N x P]
# Evaluating the log-density of y under the Gaussian mixture works as expected.
print(mixture.log_prob(y))
But the equivalent for a mixture of Binomials fails:
# Failing example: the same construction with Binomial components.
# Reuses torch, pyro, dist and the sizes N, P, K from the Gaussian example.
# Validation is switched on explicitly; the error is only seen with validation enabled.
pyro.enable_validation(True)
# simulate data
total_count = 100
# Integer-valued counts in [0, total_count], one per point and dimension.
y = dist.Binomial(total_count, torch.full([N,P],0.5)).sample()
print("y shape", y.shape)
# Mixing distribution: one Categorical over the K components, expanded to a batch of N.
mix = dist.Categorical(torch.rand(K)).expand([N])
#mix = dist.Categorical(torch.rand([N,K])) # same error
print("mix shapes", mix.batch_shape, mix.event_shape) # [N], []
# Component distributions: Binomial with random success probabilities.
comp_batch = dist.Binomial(total_count, torch.rand(N, K, P))
print("comp_batch shapes:",comp_batch.batch_shape, comp_batch.event_shape) # [N x K x P], []
# Reinterpret the rightmost batch dimension (P) as an event dimension.
#comp = comp_batch.to_event(1)
comp = dist.Independent(comp_batch, reinterpreted_batch_ndims = 1) # equivalent
print("comp shapes:", comp.batch_shape, comp.event_shape) # [N x K], [P]
# Mix over the K dimension of the components.
mixture_batch = dist.MixtureSameFamily(mix, comp)
print("mixture_batch shapes:", mixture_batch.batch_shape, mixture_batch.event_shape) # [N], [P]
# Reinterpret the remaining batch dimension (N) as an event dimension as well.
mixture = mixture_batch.to_event(1)
print("mixture shapes:", mixture.batch_shape, mixture.event_shape) # [], [N x P]
# With validation enabled this call raises a RuntimeError from the support check
# in torch/distributions/constraints.py (size 10 vs. size 2 at dimension 1).
# NOTE(review): presumably the Binomial support bounds carry the [N, K, P]
# component batch shape, which cannot broadcast against the [N, P] value -- confirm.
print(mixture.log_prob(y))
This gives the following error:
.../torch/distributions/constraints.py in check(self, value)
297 def check(self, value):
298 return (
--> 299 (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
300 )
301
RuntimeError: The size of tensor a (10) must match the size of tensor b (2) at non-singleton dimension 1
With validation turned off, it runs without error.