Trouble with custom Torch distribution and batch_shape

So I am trying to do some modeling with a Plackett-Luce distribution. I found a project on GitHub that implements it here:

I can get everything else to work: I build the model and run SVI inference with an AutoDiagonalNormal guide, and I know it’s working because it recovers fairly accurate loc and scale parameters:

guide.requires_grad_(False)

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))

AutoDiagonalNormal.loc Parameter containing:
tensor([-0.6924, -0.6104, -0.3954,  1.6901,  2.7612, -0.0822,  0.0528,  0.0675,
        -0.0353,  0.0311, -0.1520,  0.8171, -0.0276, -0.1319,  0.9564, -0.7052,
         0.6423, -0.0239, -0.4948,  0.2907,  0.2828])
AutoDiagonalNormal.scale tensor([0.4231, 0.2462, 0.1886, 0.3141, 0.2906, 0.2132, 0.3147, 0.1858, 0.1911,
        0.1777, 0.1913, 0.2880, 0.1956, 0.1833, 0.3105, 0.2190, 0.2974, 0.2267,
        0.2141, 0.2525, 0.1954])

So far so good. The trouble in paradise is that when I try to sample using Predictive, I get the following error about batch_shape:

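import torch
from pyro.infer import Predictive

def summary(samples):
    site_stats = {}
    for k, v in samples.items():
        v = v.float()  # ranking samples may be integer indices; mean/std need a float tensor
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            # kthvalue returns (values, indices); reduce over the sample dimension 0
            "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
            "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
        }
    return site_stats


predictive = Predictive(model, guide=guide, num_samples=800, return_sites=("obs", "_RETURN"))
samples = predictive(data)
pred_summary = summary(samples)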
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-242-4df7bb95db13> in <module>
     12 
     13 predictive = Predictive(model, guide=guide, num_samples=800, return_sites=("obs","_RETURN"))
---> 14 samples = predictive(data)
     15 pred_summary = summary(samples)

~/pyro/xpm/pyro/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/pyro/xpm/pyro/lib/python3.7/site-packages/pyro/infer/predictive.py in forward(self, *args, **kwargs)
    199                                             parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    200         return _predictive(self.model, posterior_samples, self.num_samples, return_sites=return_sites,
--> 201                            parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    202 
    203     def get_samples(self, *args, **kwargs):

~/pyro/xpm/pyro/lib/python3.7/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
     68     return_site_shapes = {}
     69     for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
---> 70         append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
     71         site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape
     72         # non-empty return-sites

AttributeError: 'PlackettLuce' object has no attribute 'batch_shape'
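The failing line is Predictive's shape bookkeeping: for every sample site it computes max_plate_nesting - len(fn.batch_shape), so whatever object is stored as the site's fn must expose batch_shape. The sketch below is a simplified copy of lines 70 and 71 from the traceback. The helper name predictive_site_shape is hypothetical, torch's built-in OneHotCategorical is only a stand-in with the same shapes as the obs site (batch (21,), event (20,)), and max_plate_nesting=1 is an assumption.

```python
import torch
import torch.distributions as td


def predictive_site_shape(fn, value, num_samples, max_plate_nesting):
    # Simplified copy of lines 70-71 of pyro/infer/predictive.py from the traceback:
    # fn must have a batch_shape attribute, otherwise this raises AttributeError.
    append_ndim = max_plate_nesting - len(fn.batch_shape)
    return (num_samples,) + (1,) * append_ndim + tuple(value.shape)


# Stand-in with the obs site's layout: 21 rows, each an event over 20 items.
fn = td.OneHotCategorical(logits=torch.zeros(21, 20))
print(fn.batch_shape, fn.event_shape)  # torch.Size([21]) torch.Size([20])
print(predictive_site_shape(fn, fn.sample(), num_samples=800, max_plate_nesting=1))
# (800, 21, 20)
```

Any subclass of torch.distributions.Distribution gets batch_shape from its base constructor; a class that never calls that constructor does not, which is exactly the AttributeError above.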

I’m not sure what to do here. I tried using TorchDistributionMixin as the base class, but then the inference step raises an error:

ValueError: Error while computing log_prob at site 'obs':
shape mismatch: objects cannot be broadcast to a single shape: (21,) vs torch.Size([21, 20])
Trace Shapes:        
 Param Sites:        
Sample Sites:        
   sigma dist       |
        value       |
     log_prob       |
   theta dist    20 |
        value    20 |
     log_prob    20 |
     obs dist    21 |
        value 21 20 |
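In the trace, obs dist shows 21 to the left of the bar and nothing to the right, so the distribution declares batch_shape (21,) with an empty event_shape, while each observed value is a length-20 ranking, giving a (21, 20) value. A small, hypothetical helper like check_shapes can catch this mismatch before SVI does. torch's built-in OneHotCategorical and Categorical are used here only as stand-ins for a correctly and an incorrectly declared distribution over 20 items.

```python
import torch
import torch.distributions as td


def check_shapes(d, value):
    """Hypothetical debugging helper: value must end in batch_shape + event_shape,
    and log_prob must drop exactly the event dimensions."""
    expected = d.batch_shape + d.event_shape
    n = len(expected)
    if value.dim() < n or value.shape[value.dim() - n:] != expected:
        raise ValueError(f"value shape {tuple(value.shape)} does not end in "
                         f"batch_shape + event_shape = {tuple(expected)}")
    lp = d.log_prob(value)
    if lp.shape != value.shape[: value.dim() - len(d.event_shape)]:
        raise ValueError(f"log_prob shape {tuple(lp.shape)} is wrong")
    return lp.shape


good = td.OneHotCategorical(logits=torch.zeros(21, 20))  # batch (21,), event (20,)
value = good.sample()                                   # shape (21, 20)
print(check_shapes(good, value))                        # torch.Size([21])

bad = td.Categorical(logits=torch.zeros(21, 20))        # batch (21,), event ()
try:
    check_shapes(bad, value)
except ValueError as err:
    print(err)  # same situation as the obs site: (21, 20) vs (21,)
```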

Any help would be greatly appreciated!

@thecity2 You will need to declare batch_shape (and event_shape) in the constructor, something like:

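def __init__(self, logits, validate_args=None):
    self.logits = logits               # store first; the base __init__ may validate it
    batch_shape = logits.shape[:-1]    # one ranking per leading index
    event_shape = logits.shape[-1:]    # the ranked items form a single event
    super(PlackettLuce, self).__init__(batch_shape, event_shape, validate_args=validate_args)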
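With a constructor like that, the base class derives both shapes from logits. A minimal sketch, using a hypothetical ShapesOnly class that does nothing except call the torch.distributions.Distribution constructor, shows what a (21, 20) logits tensor produces:

```python
import torch
from torch.distributions import Distribution


class ShapesOnly(Distribution):
    """Hypothetical stand-in: only sets up shapes, no sampling or log_prob."""
    arg_constraints = {}

    def __init__(self, logits, validate_args=None):
        self.logits = logits
        super().__init__(logits.shape[:-1], logits.shape[-1:], validate_args=validate_args)


d = ShapesOnly(torch.zeros(21, 20))          # 21 rankings over 20 items, as at the obs site
print(d.batch_shape)                         # torch.Size([21])
print(d.event_shape)                         # torch.Size([20])
print(d._extended_shape(torch.Size([800])))  # torch.Size([800, 21, 20])
```

_extended_shape is the (private) torch helper that prepends sample_shape, which is the shape a sample(sample_shape) method should return.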

You might also need to

  • inherit your class from TorchDistribution: from pyro.distributions.torch_distribution import TorchDistribution
  • implement the .expand method
  • make sample work with sample_shape (instead of num_samples)

as in the MultivariateStudentT implementation.
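Putting those pieces together, a complete class might look roughly like this. It is a sketch, not the GitHub project's implementation: it assumes logits holds one unnormalized score per item along the last dimension and that each sample or observation is a permutation of item indices, best first. Sampling uses the Gumbel-max trick (argsort of logits plus Gumbel noise), and log_prob is the Plackett-Luce likelihood: the sum over positions of the chosen item's score minus the log-sum-exp of the scores of the items not yet ranked.

```python
import torch
from torch.distributions import constraints

try:
    # Base class suggested in the reply (requires Pyro).
    from pyro.distributions.torch_distribution import TorchDistribution as Base
except ImportError:
    # Fallback so the sketch also runs with plain PyTorch.
    Base = torch.distributions.Distribution


class PlackettLuce(Base):
    """Sketch: a random ranking of the items along the last dim of logits.

    Assumed encoding: a sample is a permutation of item indices, best first.
    """

    arg_constraints = {"logits": constraints.real_vector}

    def __init__(self, logits, validate_args=None):
        self.logits = logits
        batch_shape = logits.shape[:-1]  # one ranking per leading index
        event_shape = logits.shape[-1:]  # a full ranking of n items is one event
        super(PlackettLuce, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Usual torch pattern: new instance with parameters broadcast to batch_shape.
        new = self._get_checked_instance(PlackettLuce, _instance)
        batch_shape = torch.Size(batch_shape)
        new.logits = self.logits.expand(batch_shape + self.event_shape)
        super(PlackettLuce, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        # Gumbel-max trick: sorting logits + i.i.d. Gumbel noise gives a PL ranking.
        shape = self._extended_shape(sample_shape)  # sample_shape + batch_shape + event_shape
        with torch.no_grad():
            u = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
            u = u.clamp(min=torch.finfo(u.dtype).tiny)  # avoid log(0)
            gumbel = -torch.log(-torch.log(u))
            return torch.argsort(self.logits.expand(shape) + gumbel, dim=-1, descending=True)

    def log_prob(self, value):
        # value: permutations of item indices, shape (..., n); result has shape (...)
        logits, value = torch.broadcast_tensors(self.logits, value.long())
        ordered = torch.gather(logits, -1, value)  # scores in ranked order
        # log-sum-exp over the items still unranked at each position
        remaining = torch.logcumsumexp(ordered.flip(-1), dim=-1).flip(-1)
        return (ordered - remaining).sum(-1)


d = PlackettLuce(torch.randn(21, 20))
s = d.sample(torch.Size([800]))
print(d.batch_shape, d.event_shape, s.shape, d.log_prob(s).shape)
# torch.Size([21]) torch.Size([20]) torch.Size([800, 21, 20]) torch.Size([800, 21])
```

With Pyro's validation enabled, a batch dimension of size 21 generally needs a matching pyro.plate around the obs site; alternatively, .to_event(1) moves it into the event shape.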