Von Mises-Fisher 4D - scoring and sampling

I’d like to have the von Mises-Fisher distribution in Pyro, so that I can score it and sample from it.

Sampling from that distribution is implemented in the Python package geomstats: https://github.com/geomstats/geomstats/blob/master/geomstats/geometry/hypersphere.py#L454

I am doing amortized inference, and I want to have a deep net that learns the mu (loc) and kappa (concentration) from data, so I need to use pyro.sample(von_mises_fisher_to_be_implemented, obs=data) in my guide.

I’m willing to do work to get this (PRs, etc). It’s a high priority for me. I would need some guidance, though.

Hi @geoffwoollard, it looks like dist.VonMises3D already implements .log_prob(), so you’d only need to implement a .sample() or better .rsample() method.

But before we get into the details of .sample() and .rsample(), I’d like to try to persuade you to instead use dist.ProjectedNormal, possibly together with ProjectedNormalReparam. In black box variational inference, there is a variance reduction trick known as the “reparametrization trick”, whereby samples can be differentiated with respect to their parameters. This trick can dramatically speed up inference by replacing high-variance score function gradients with lower-variance pathwise gradients. In torch.distributions, most distributions provide a “reparametrized sample” method named .rsample(), and a simpler non-differentiable .sample() method that leads to higher-variance gradients. But the .rsample() method is often much more difficult to implement than a .sample() method, because it needs to implement a backward method. For this reason, I gave up trying to implement a VonMises3D.rsample() method and found a different distribution, the projected Normal distribution, for which .rsample() is easy and inexpensive to implement. ProjectedNormal should be a drop-in replacement for VonMises3D: they both take a single concentration parameter.
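To make the variance claim concrete, here is a minimal plain-Python sketch (not Pyro code) that compares the two gradient estimators on a toy problem: estimating d/dmu E[x**2] for x ~ Normal(mu, 1), whose exact value is 2 * mu. The function name gradient_estimates and the toy objective are assumptions chosen for illustration only.

```python
import random
import statistics

def gradient_estimates(mu, n_samples, seed=0):
    """Per-sample estimates of d/dmu E[x**2] for x ~ Normal(mu, 1)."""
    rng = random.Random(seed)
    score, pathwise = [], []
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        x = mu + eps  # reparametrized draw: x is a differentiable function of mu
        # Score-function estimator: f(x) * d/dmu log p(x | mu) = x**2 * (x - mu)
        score.append(x * x * (x - mu))
        # Pathwise estimator: d/dmu f(mu + eps) = 2 * (mu + eps)
        pathwise.append(2.0 * x)
    return score, pathwise

score, pathwise = gradient_estimates(mu=1.0, n_samples=20000)
# Both estimators target the exact gradient 2 * mu = 2.0.
print("score-function mean/variance:", statistics.mean(score), statistics.variance(score))
print("pathwise mean/variance:      ", statistics.mean(pathwise), statistics.variance(pathwise))
```

Both estimators should average to about 2, but the pathwise one does so with a much smaller per-sample variance, which is the practical advantage of .rsample() over .sample() in SVI.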

If you absolutely need a VonMises3D distribution, I’d recommend subclassing locally and implementing a (higher variance) .sample() method by wrapping geomstats, something like the following:

import torch
import pyro.distributions
import geomstats

class VonMises3D(pyro.distributions.VonMises3D):
    def sample(self, sample_shape=torch.Size()):
        # Draw with geomstats (NumPy), then convert; no gradients flow through this.
        np_sample = geomstats...random_von_mises_fisher(...)
        return torch.tensor(np_sample, dtype=self.concentration.dtype)

If you’d like to do more work, you could try porting the geomstats...random_von_mises_fisher() implementation to PyTorch by creating a new pyro.distributions.VonMises3D.sample() method and contributing a PR. But again, I suspect ProjectedNormal.rsample() will work better than VonMises3D.sample() (due to the former’s use of the reparametrization trick), and the former will also be less work.
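The reason a reparametrized sampler is easy for the projected Normal is that a draw is just a Gaussian draw centered at the concentration vector, rescaled to unit length. The following is a plain-Python sketch of that idea under this assumption; it is not Pyro's implementation, and the function name projected_normal_rsample is hypothetical.

```python
import math
import random

def projected_normal_rsample(concentration, rng):
    """Draw a unit vector by normalizing a sample from Normal(concentration, I)."""
    # The noise is independent of the parameters, so the result is a
    # deterministic, differentiable function of `concentration`.
    noisy = [c + rng.gauss(0.0, 1.0) for c in concentration]
    norm = math.sqrt(sum(v * v for v in noisy))
    return [v / norm for v in noisy]

rng = random.Random(0)
concentration = [5.0, 0.0, 0.0, 0.0]  # points along the first axis
q = projected_normal_rsample(concentration, rng)
print(q, sum(v * v for v in q))  # a unit 4-vector, usable as a quaternion
```

The length of the concentration vector controls how tightly samples cluster around its direction: a longer vector gives draws closer to that direction.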

Good luck!

This is really helpful! Let me see if I can get things working with the projected normal.

Oh, I was going to learn unit quaternions… so I wanted it to take in a unit 4 vector… I got

NotImplementedError: ProjectedNormal.log_prob() is not implemented for dim = 4.

This works fine with n_dim=3, but fails with n_dim=4.

Would it take too much work to get this working???

qs = torch.arange(n_dim*n_batch).reshape(n_batch,n_dim).float().cuda()
q = torch.zeros(n_dim)
score_q = q.reshape(-1,1).expand(-1,n_batch).T.float().cuda()

KeyError                                  Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pyro/distributions/projected_normal.py in log_prob(self, value)
    104         try:
--> 105             impl = self._log_prob_impls[dim]
    106         except KeyError:

KeyError: 4

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
1 frames
/usr/local/lib/python3.7/dist-packages/pyro/distributions/projected_normal.py in log_prob(self, value)
    108             if value.requires_grad:  # For latent variables but not observations.
    109                 msg += " Consider using poutine.reparam with ProjectedNormalReparam."
--> 110             raise NotImplementedError(msg)
    111         return impl(self.concentration, value)

NotImplementedError: ProjectedNormal.log_prob() is not implemented for dim = 4.

Hi @geoffwoollard, I think it should be only a little work to get this working for dim=4, let me try quick…

Hi @geoffwoollard, Here’s a pull request supporting 4D ProjectedNormal. Does that work for you?

Yes, I think that should do it!

Wow, that is an amazingly fast response! Thanks! I usually run !pip3 install pyro-ppl. How long will it take for this patch to reach that release? Can I get it sooner by installing from source? I’ve been developing in a Colab notebook.

Until that PR merges, you can:

!pip install https://github.com/pyro-ppl/pyro/archive/projected-normal-4.zip

After that PR merges and until our next release (a few months from now) you can:

!pip install https://github.com/pyro-ppl/pyro/archive/dev.zip

which is what I often do in Jupyter notebooks. After our next release you should be able to use your old

!pip install pyro-ppl

Hmm… I’m getting NaNs in my svi.step quite a bit, but not all the time.

Not sure how to troubleshoot this…

I tried decreasing the learning rate, increasing the batch size, and using num_particles from 10 to 100 in pyro.infer.Trace_ELBO(num_particles).

/usr/local/lib/python3.7/dist-packages/pyro/poutine/trace_struct.py:248: UserWarning: Encountered NaN: log_prob_sum at site 'quaternions'
  "log_prob_sum at site '{}'".format(name),
/usr/local/lib/python3.7/dist-packages/pyro/infer/trace_elbo.py:158: UserWarning: Encountered NaN: loss
  warn_if_nan(loss, "loss")

def model(mini_batch,do_log=True):
  # forward model with noise
  # point_cloud defined outside function
  size_mini_batch = mini_batch.shape[-1]
  with pyro.plate('mini_batch',size_mini_batch):

    quaternion_dist = dist.ProjectedNormal(concentration_cuda*q_loc_cuda)
    quaternions = pyro.sample('quaternions',quaternion_dist)
    rotations = pytorch3d.transforms.quaternion_to_matrix(quaternions)

    clean_signal = do_forward_model(rotations, point_cloud) # backpropable

    with pyro.plate('pixel_x',num_pix, dim=-2):
      with pyro.plate('pixel_y',num_pix, dim=-3):
        distrib = dist.Normal(clean_signal, sigma_noise)
        pyro.sample("shot_noise", distrib, obs=mini_batch) 
        sim = distrib.sample() 
        return sim, quaternions

def guide(mini_batch,do_log=True): # the proposal distribution
  # neural net will be trained on many particles to predict params of distribution

  pyro.module("net", net)
  size_mini_batch = mini_batch.shape[-1]
  with pyro.plate('mini_batch', size_mini_batch, dim=-1):
    lam = net(mini_batch)
    q_loc,log_concentration = lam[:,:4], lam[:,4]
    q_unit_loc = q_loc / q_loc.norm(dim=1).reshape(-1,1)
    concentration = torch.exp(log_concentration)
    quaternion_dist = dist.ProjectedNormal(concentration.reshape(-1,1)*q_unit_loc)    
    quaternions = pyro.sample("quaternions", quaternion_dist)

    return quaternions, lam

svi = pyro.infer.SVI(model=model, 
                     optim=pyro.optim.Adam({"lr": 1e-7}), 

I also tried this with a simpler problem in 2D, where getting the scalar 2D rotation angle from a von Mises distribution works, in the sense that I can train a NN to do amortized inference and predict the rotation angle.

But swapping in ProjectedNormal for von Mises and doing the downstream x-y-vector to rotation conversions gives NaNs.

I can still train in the 2D case and get good correlation between ground truth and prediction, but there are all these NaNs in the svi.step…

Interestingly, the NaN warnings stop after many epochs, when there is a dip in the loss curve (around 1500 on the x-axis).

Is this related to the variance in the gradient, and would using the ProjectedNormalReparam help?


Ok update, the same thing happens in the 3D case using the quaternions.

Hmm, the NaNs might be caused by either ProjectedNormal or by the parametrization. I’d actually first try directly parametrizing concentration and avoiding the coordinate transforms, which can be numerically unstable. I.e. just let your net output concentration directly:

concentration = net(mini_batch)
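As a rough illustration of why the direct parametrization loses nothing, the plain-Python sketch below shows that any nonzero raw vector already encodes both a direction and a strength, and that the transformed form exp(log_strength) * q_loc / |q_loc| rebuilds the same vector. The helper names split_concentration and polar_to_concentration, and the word "strength" for the vector's length, are hypothetical and used only for this example.

```python
import math

def split_concentration(concentration):
    """Split a raw concentration vector into (unit direction, strength)."""
    strength = math.sqrt(sum(c * c for c in concentration))
    direction = [c / strength for c in concentration]
    return direction, strength

def polar_to_concentration(q_loc, log_strength):
    """Transformed parametrization: exp(log_strength) * q_loc / |q_loc|."""
    norm = math.sqrt(sum(v * v for v in q_loc))
    scale = math.exp(log_strength) / norm
    return [scale * v for v in q_loc]

raw = [0.6, -1.2, 0.3, 2.0]  # stand-in for one row of a raw network output
direction, strength = split_concentration(raw)
rebuilt = polar_to_concentration(direction, math.log(strength))
print(direction, strength, rebuilt)  # rebuilt matches raw up to rounding
```

Since the two forms describe the same set of vectors, letting the network emit the raw vector avoids the division and the exponential entirely.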

If you really want to use the transformed parameters, you might consider numerically stabilizing as follows:

  q_loc,log_concentration = lam[:,:4], lam[:,4]
- q_unit_loc = q_loc / q_loc.norm(dim=1).reshape(-1,1)
+ q_unit_loc = q_loc / q_loc.norm(dim=1).reshape(-1,1).clamp(min=1e-20)
- concentration = torch.exp(log_concentration)
+ concentration = torch.nn.functional.softplus(log_concentration)

And note that concentration is real-valued, not positive, so I don’t see how it makes sense to exponentiate it:

concentration = torch.exp(log_concentration)
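The two stabilizations in the diff above can be sketched in plain Python to show what they guard against. This is an analogue of the torch calls, not the torch code itself; safe_normalize and this softplus are hypothetical helpers written for illustration.

```python
import math

def safe_normalize(vec, min_norm=1e-20):
    """Divide by the norm, clamped from below so a zero vector cannot give 0/0."""
    norm = max(math.sqrt(sum(v * v for v in vec)), min_norm)
    return [v / norm for v in vec]

def softplus(x):
    """log(1 + exp(x)), computed without overflow; grows linearly for large x."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

print(safe_normalize([0.0, 0.0, 0.0, 0.0]))  # stays finite instead of NaN
print(softplus(100.0), math.exp(100.0))      # roughly 100 versus roughly 2.7e43
```

An exponential turns a moderately large network output into an enormous concentration, while softplus keeps the output positive but grows only linearly, which is gentler on the optimizer.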

@geoffwoollard I do not know how I found myself reading this conversation, but if you want to sample unit quaternions to build a rotation matrix you can follow our implementation here: https://github.com/LysSanzMoreta/Theseus-PP/blob/master/SUPERPOSITION.py#L234-L285

We also had some problems with NaN values etc., which we resolved with the Student-t.

I hope I understood correctly what you wanted to do :)
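For readers following along, here is a minimal plain-Python sketch of the standard conversion from a unit quaternion to a 3x3 rotation matrix. It assumes the real-part-first ordering (w, x, y, z); it is not the code from the linked repository or from pytorch3d, and the function is written only for illustration.

```python
import math

def quaternion_to_matrix(q):
    """Rotation matrix for a unit quaternion given as (w, x, y, z)."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]

# A 90-degree rotation about the z-axis: w = cos(45 deg), z = sin(45 deg).
half = math.pi / 4
R = quaternion_to_matrix([math.cos(half), 0.0, 0.0, math.sin(half)])
print(R)  # maps the x-axis onto the y-axis, up to rounding
```

Note that q and -q give the same matrix, so a distribution over unit quaternions covers each rotation twice.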

Hi @geoffwoollard, just checking in, were you able to get this working with the simpler parametrization?

Yes. It worked. Very happy with it :)