Probabilistic model with large matrix multiplication inside

Hi all,

I’m trying to see whether Pyro would be good for solving inverse problems degraded by noise, specifically problems of the form

F = Hg + n

where F is the recorded signal, g is the true object, n is noise, and H is some corrupting measurement matrix.

As I understand it, with Pyro I can build a forward model of Hg + n, and it will use PyTorch's backpropagation to iterate towards, say, the maximum likelihood. My question is: will Pyro handle this properly if H is a large matrix for which only pseudo-inverses are available? Is large matrix inversion something that PyTorch itself handles well?

Best

Craig

pytorch has a pinverse method but i’ve personally never used it
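For a small toy problem the pseudo-inverse route is a one-liner. The sketch below assumes the 5x5 circulant `H` that appears later in this thread and noise-free data, and uses `torch.linalg.pinv` (`torch.pinverse` is documented as an alias for it). `pinv` works through a dense SVD, roughly O(n^3) time with H stored densely, so it is only an option for small n: a dense 400000 x 400000 float32 matrix alone would take about 640 GB.

```python
import torch
from scipy.linalg import circulant

# Toy setup (assumption): the 5x5 circulant blur used later in the thread, no noise
H = torch.tensor(circulant([1, 0.5, 0, 0, 0.5]), dtype=torch.float64)
g_true = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float64)
F_obs = H @ g_true

# Dense pseudo-inverse via SVD: fine for n = 5, infeasible for n = 400000
H_pinv = torch.linalg.pinv(H)
g_hat = H_pinv @ F_obs
print(g_hat)
```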

why do you need to invert H?

I think I’d misunderstood how Pyro works: I assumed PyTorch would invert H, rather than optimise a loss that drives Hg - F towards 0.
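As a sketch of that idea in plain PyTorch (no Pyro yet), gradient descent on the squared residual ||Hg - F||^2 only ever evaluates `H @ g`, so no inverse is formed. The toy `H`, the noise-free data and the Adam settings are assumptions chosen for illustration.

```python
import torch
from scipy.linalg import circulant

# Toy problem (assumption): 5x5 circulant blur, noise-free observation
H = torch.tensor(circulant([1, 0.5, 0, 0, 0.5]), dtype=torch.float64)
g_true = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float64)
F_obs = H @ g_true

g = torch.zeros(5, dtype=torch.float64, requires_grad=True)
opt = torch.optim.Adam([g], lr=0.01)
losses = []
for step in range(2000):
    opt.zero_grad()
    # Squared residual: only the forward product H @ g is needed, never H^-1
    loss = ((H @ g - F_obs) ** 2).sum()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
print(g.detach())
```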

i’m assuming H is known?

pyro supports various inference algorithms. for the simplest ones to try on this sort of problem, namely MAP/MLE and SVI, you wouldn’t need to compute H’s (pseudo)inverse.

roughly speaking your model would look like

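```python
import pyro
import torch

def model(F, H):
    # GPriorDistribution and NoiseDistribution are placeholders for distributions suited to the problem
    g = pyro.sample("g", GPriorDistribution(...))
    Hg = torch.matmul(H, g)  # the forward model only needs H @ g, never H's inverse
    pyro.sample("F", NoiseDistribution(Hg, ...), obs=F)
```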

you’d just need to fill in distributions appropriate to your problem

H is known, yes, and each entry in the g vector follows an independent Poisson distribution.

that is potentially a difficult problem since g is discrete (how difficult will depend on dimensionality etc.), but you could certainly give SVI a try. if the dimension is relatively moderate it could work.

H is a sparse matrix (usually mostly zeros), rectangular in general and up to about 400000 x 400000, and g is a 1D vector up to 400000 in length.

yeah that’s not an easy problem to solve, especially with black box methods. what magnitude of counts in g do you expect to encounter? can you replace the poisson distribution with a continuous relaxation? that’d make inference much easier.

The problem I’m inverting is image formation, specifically for microscopy. I was hoping to build the model up in complexity to include a time component (using Markov chains) as well, rather than simplify it.

H is (most of the time) a Toeplitz matrix, so under most conditions the H*g part is a convolution with a small kernel (which would help with computation, I guess?).
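To illustrate the convolution point: when H is circulant (a Toeplitz matrix that wraps around at the edges), H @ g is a circular convolution, so it can be computed from a small kernel with `torch.nn.functional.conv1d` without ever building H. The sketch checks this on the 5x5 toy H used later in the thread; the 3-tap kernel and the circular boundary are assumptions tied to that toy, and for images the analogous calls would be `conv2d`/`conv3d`.

```python
import torch
import torch.nn.functional as F
from scipy.linalg import circulant

# Toy H from the thread: (H @ x)[i] = 0.5 * x[i-1] + x[i] + 0.5 * x[i+1], indices wrapping around
H = torch.tensor(circulant([1, 0.5, 0, 0, 0.5]), dtype=torch.float64)
kernel = torch.tensor([0.5, 1.0, 0.5], dtype=torch.float64)  # the 3-tap blur that H encodes


def apply_H_conv(x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Compute H @ x as a circular convolution without materializing H."""
    half = kernel.numel() // 2
    # conv1d expects (batch, channels, length); circular padding reproduces the wrap-around
    x_padded = F.pad(x.view(1, 1, -1), (half, half), mode="circular")
    # conv1d is a cross-correlation; that matches H here because the kernel is symmetric
    return F.conv1d(x_padded, kernel.view(1, 1, -1)).view(-1)


x = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float64)
print(torch.allclose(H @ x, apply_H_conv(x, kernel)))  # True
```

With an asymmetric kernel, the kernel would need to be flipped for conv1d to match H exactly.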

I was also broadly assuming that the sparsity of H would help, as it usually masks out contributions to the pixel currently being observed from all but a subset of nearby pixels.
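Sparsity does help with memory, and PyTorch can differentiate through sparse-dense products. A dense 400000 x 400000 float32 H would need about 640 GB, whereas a sparse COO matrix stores only its nonzeros: with a hypothetical 9 nonzeros per row (a 3x3 kernel), that is about 3.6 million entries, on the order of tens of megabytes. The sketch below uses a hypothetical 1D banded helper, `banded_blur_matrix`, with no wrap-around, and `torch.sparse.mm`; the gradient of sum(H @ g) with respect to g is the vector of column sums of H.

```python
import torch


def banded_blur_matrix(n: int) -> torch.Tensor:
    """Hypothetical helper: sparse n x n matrix computing 0.5*g[i-1] + g[i] + 0.5*g[i+1]."""
    rows, cols, vals = [], [], []
    for i in range(n):
        for offset, weight in ((-1, 0.5), (0, 1.0), (1, 0.5)):
            j = i + offset
            if 0 <= j < n:  # no wrap-around, so the first and last rows have two entries
                rows.append(i)
                cols.append(j)
                vals.append(weight)
    indices = torch.tensor([rows, cols])
    values = torch.tensor(vals, dtype=torch.float64)
    return torch.sparse_coo_tensor(indices, values, size=(n, n)).coalesce()


n = 8
H_sparse = banded_blur_matrix(n)
print(H_sparse.values().numel(), "stored entries instead of", n * n)

g = torch.rand(n, dtype=torch.float64, requires_grad=True)
Hg = torch.sparse.mm(H_sparse, g.unsqueeze(1)).squeeze(1)  # sparse @ dense -> dense
Hg.sum().backward()
print(g.grad)  # column sums of H: 1.5 at both ends, 2.0 elsewhere
```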

In this case we use a Poisson distribution because that’s physically representative of photons being emitted and received at the detector. There are known iterative image reconstruction techniques that do maximum likelihood estimation (Richardson-Lucy deconvolution), but they mostly fall down when you add the time element into the model.
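For comparison with the Pyro approach, here is a minimal sketch of the standard Richardson-Lucy update for y ~ Poisson(H @ x), x <- x * H^T(y / (H @ x)) / H^T 1, on the 5x5 toy problem from later in the thread (noise-free data and the iteration count are assumptions). Richardson-Lucy is an EM algorithm for this Poisson likelihood, so the negative log-likelihood should not increase from one iteration to the next.

```python
import torch
from scipy.linalg import circulant


def poisson_nll(x, y, H):
    """Negative Poisson log-likelihood of y ~ Poisson(H @ x), dropping terms constant in x."""
    Hx = H @ x
    return (Hx - y * torch.log(Hx)).sum()


def richardson_lucy(y, H, x0, n_iter=200):
    x = x0.clone()
    col_sums = H.sum(dim=0)  # H^T 1
    for _ in range(n_iter):
        # Multiplicative update: stays positive if x0, H and y are positive
        x = x * (H.T @ (y / (H @ x))) / col_sums
    return x


H = torch.tensor(circulant([1, 0.5, 0, 0, 0.5]), dtype=torch.float64)
x_true = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float64)
y = H @ x_true
x0 = torch.full((5,), 0.5, dtype=torch.float64)

x_rl = richardson_lucy(y, H, x0)
print(poisson_nll(x0, y, H).item(), poisson_nll(x_rl, y, H).item())
print(x_rl)
```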

I’m definitely going to play around and figure out where the limits of computational viability come in, but it is accepted that large image deconvolution can be very slow.

Thanks for your guidance so far, you’ve been super helpful.

one thing you could do is

  • truncate the poisson distribution at some N_max (and renormalize)
  • then use something like RelaxedOneHotCategorical in conjunction with your truncated Poisson distribution to define a variational distribution
  • then do variational inference

basically you’d be using a continuous relaxation to compute approximate gradients but you’d still be using the desired poisson distribution in the model (though it’d be truncated).

still, this would be an approximation and it’s not clear a priori how well it’d work. perhaps poorly, hard to say.

Hi again,

I found time to have another attempt, and I’ve managed to get MLE working for a 1D case with F = Hg + n (I’ve used x and y instead of g and F). This seems to be working well; however, I’d like to know how one would handle the case F = H*(g + n), since torch.matmul doesn’t work with tensor distributions. Do I have to manage the multiplication myself, should I propagate a dist.* to each element in the tensor, or can I run a pyro.sample of the distribution before I apply H and then use a condition? I tried the latter but it didn’t work; code below.

Thanks for your advice so far. Truncating the Poisson distribution is very sensible for this use case, as there is a physical upper limit in the imaging system.

```python
# %% Imports
# Borrowed from: https://pyro.ai/examples/mle_map.html
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from scipy.linalg import circulant
from torch.distributions import constraints

pyro.set_rng_seed(42)
# The toy observations y are not integers, so distribution validation (which would
# reject them as outside the Poisson support) is switched off explicitly
pyro.enable_validation(False)

# %% Setup tensors
x_true = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]).double()
x_0 = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5]).double()  # initial guess

# 5x5 circulant blur: each output mixes a pixel with half of each neighbour
H = torch.tensor(circulant([1, 0.5, 0, 0, 0.5])).double()

y = torch.matmul(H, x_true)  # noise-free "measurement"
y_0 = torch.matmul(H, x_0)


# %% Define forward model

# %% y = H*(x+n): the attempt that did not work
# pyro.condition only acts on sample sites; "Hx" is a plain tensor, not a site,
# so nothing is observed here. This definition is also overwritten by the next one.
@pyro.condition(data={"Hx": y})
def model(y, H):
    x = pyro.param("x", x_0, constraint=constraints.positive)
    x_prior = pyro.sample("x_prior", dist.Poisson(x))
    Hx = torch.matmul(H, x_prior)


# %% y = (H*x)+n: the working model, i.e. MLE for y ~ Poisson(H @ x)
def model(y, H):
    x = pyro.param("x", torch.tensor(x_0), constraint=constraints.positive)
    Hx = torch.matmul(H, x)
    pyro.sample("f", dist.Poisson(Hx), obs=y)


# %% MLE means we can ignore the guide function
def guide(y, H):
    pass


# %% Setup the optimizer
adam_params = {"lr": 0.005}
optimizer = Adam(adam_params)

# Setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

loss = []
# %% Begin training
pyro.clear_param_store()
n_steps = 10000
for step in range(n_steps):
    loss.append(svi.step(y, H))
    if step % 100 == 0:
        print(".", end="")

print(f'Current guess: {pyro.param("x")}')
print(f"True x: {x_true}")
plt.plot(loss)
plt.show()
```

i don’t think this does what you want, since there is no prior on x. note that the left hand side of `x = pyro.param("x", torch.tensor(x_0), constraint=constraints.positive)` is a tensor that you can do matmuls on. there is no need to do matmuls on distributions.

I put a Poisson prior on x in the line

```python
x_prior = pyro.sample("x_prior", dist.Poisson(x))
```

to make x_prior a variable sampled from that distribution. Is this not the correct way of handling f ~ H*Poisson(x)?

The working model:

```python
def model(y, H):
    x = pyro.param("x", torch.tensor(x_0), constraint=constraints.positive)
    Hx = torch.matmul(H, x)
    pyro.sample("f", dist.Poisson(Hx), obs=y)
```

does MLE for f ~ Poisson(H*x) (I think).