Bug with batched mixture of Binomials

I’m fairly sure this is a bug in some validation code. The following mixture of Gaussians evaluates just fine:

import torch
import pyro
import pyro.distributions as dist

# Reference case: a batched mixture of Gaussians built with MixtureSameFamily.
# Each of the N data points gets a K-component mixture over P dimensions, and
# the final log_prob call evaluates without error.
N = 10 # data points
P = 5 # dimensions
K = 2 # components

# simulate data
y = torch.rand([N,P]) # N points in P dimensions
print("y shape", y.shape)

# Mixing distribution: one Categorical over K components, expanded to batch shape [N].
mix = dist.Categorical(torch.rand(K)).expand([N])
# Alternative: separate mixture weights per data point.
#mix = dist.Categorical(torch.rand([N,K])) 
print("mix shapes", mix.batch_shape, mix.event_shape) # [N], []

# Component distributions: every (data point, component, dimension) entry is
# still an independent batch element at this stage.
comp_batch = dist.Normal(torch.randn(N, K, P),1.)
print("comp_batch shapes:",comp_batch.batch_shape, comp_batch.event_shape) # [N x K x P], []

# Move the rightmost batch dimension (P) into the event shape, so the
# component batch shape ends in K as the mixture is set up to use it.
comp = comp_batch.to_event(1) 
#comp = dist.Independent(comp_batch, reinterpreted_batch_ndims = 1) # equivalent
print("comp shapes:", comp.batch_shape, comp.event_shape) # [N x K], [P]

# Combine mixing weights and components; the K dimension is marginalized out.
mixture_batch = dist.MixtureSameFamily(mix, comp)
print("mixture_batch shapes:", mixture_batch.batch_shape, mixture_batch.event_shape) # [N], [P]

# Treat the N data points as one joint event so that log_prob returns a scalar.
mixture = mixture_batch.to_event(1)
print("mixture shapes:", mixture.batch_shape, mixture.event_shape) # [], [N x P]

print(mixture.log_prob(y))

But the equivalent for a mixture of Binomials fails:

# Failing case: the same construction as the Gaussian mixture, but with
# Binomial components. Relies on torch, dist, N, P and K from the first snippet.
# With validation enabled, the final log_prob call raises the RuntimeError
# quoted below the snippet.
pyro.enable_validation(True)

# simulate data
total_count = 100
y = dist.Binomial(total_count, torch.full([N,P],0.5)).sample() 
print("y shape", y.shape)

# Mixing distribution: one Categorical over K components, expanded to batch shape [N].
mix = dist.Categorical(torch.rand(K)).expand([N])
#mix = dist.Categorical(torch.rand([N,K])) # same error
print("mix shapes", mix.batch_shape, mix.event_shape) # [N], []

# Component distributions: one Binomial per (data point, component, dimension).
comp_batch = dist.Binomial(total_count, torch.rand(N, K, P))
print("comp_batch shapes:",comp_batch.batch_shape, comp_batch.event_shape) # [N x K x P], []

# Move the rightmost batch dimension (P) into the event shape.
#comp = comp_batch.to_event(1) 
comp = dist.Independent(comp_batch, reinterpreted_batch_ndims = 1) # equivalent
print("comp shapes:", comp.batch_shape, comp.event_shape) # [N x K], [P]

mixture_batch = dist.MixtureSameFamily(mix, comp)
print("mixture_batch shapes:", mixture_batch.batch_shape, mixture_batch.event_shape) # [N], [P]

# Treat the N data points as one joint event so that log_prob returns a scalar.
mixture = mixture_batch.to_event(1)
print("mixture shapes:", mixture.batch_shape, mixture.event_shape) # [], [N x P]

# NOTE(review): presumably the support check compares y ([N, P]) against the
# Binomial's total_count broadcast to the component batch shape ([N, K, P]),
# which cannot broadcast (N=10 vs K=2) -- confirm against the traceback.
print(mixture.log_prob(y))

This gives the following error:

.../torch/distributions/constraints.py in check(self, value)
    297     def check(self, value):
    298         return (
--> 299             (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
    300         )
    301 
RuntimeError: The size of tensor a (10) must match the size of tensor b (2) at non-singleton dimension 1

With validation turned off, the same code runs without error.