Sampling from an MVN distribution with only nearest neighbours in the covariance matrix

Hi everyone,

I'm working on a model where I would like to use a guide that creates a posterior over all my latent random variables using a MultivariateNormal distribution. I could use an AutoMultivariateNormal, but since I only want to allow nearest-neighbour correlations (i.e. neighbouring sample sites can be correlated) and my parameter space is pretty big, this seems like a bad idea. Now I'm looking at the AutoLowRankMultivariateNormal guide, but I do not really see where the off-diagonal entries end up in the actual covariance matrix. My covariance matrix would be tridiagonal ("triple banded") and zero otherwise. Would the low-rank guide be an option for this model, or would I need a custom implementation?
As a custom implementation I'm considering using 3 tensors as parameters and adding an NxN zero, non-gradient matrix when I need the whole covariance matrix for sampling. Is there a way that you know of to reuse the zero entries in this matrix, i.e. have all entries except the diagonals point to the same zero, to reduce memory usage?
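A minimal sketch of what I mean (my own toy example; N and the initial values are placeholders), assembling the tridiagonal matrix from 1-D parameter tensors with torch.diag:

import torch

N = 5  # latent dimension (placeholder)
# 1-D parameter tensors: the main diagonal and the two off-diagonals
# (for a symmetric matrix the two off-diagonal tensors hold the same values)
main_diag = torch.ones(N, requires_grad=True)
upper = torch.full((N - 1,), 0.1, requires_grad=True)
lower = upper  # shared storage keeps the matrix symmetric

# torch.diag places a 1-D tensor on the k-th diagonal of an otherwise zero
# N x N matrix, so the dense tridiagonal matrix is assembled on demand
cov = torch.diag(main_diag) + torch.diag(upper, 1) + torch.diag(lower, -1)

As far as I know dense tensors cannot share storage between individual entries, so the zeros are materialised on every call; a sparse tensor would avoid that, but MultivariateNormal does not accept one directly.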

Thanks in advance

Abel

Hi @abelstam,

First I suspect you’ll probably want a sparse precision matrix rather than a sparse covariance matrix. Precision matrix = inverse of covariance matrix. Sparse precision matrices correspond to sparsity in the graphical model sense, whereas sparse covariance matrices are non-physical.
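A small illustration of this point (a toy 3-dimensional example, not from your model): a Gaussian chain x1 - x2 - x3 has a tridiagonal precision matrix, but its covariance (the inverse) is fully dense, so nearest-neighbour structure lives in the precision matrix rather than the covariance matrix.

import torch

# tridiagonal precision of a Gaussian chain x1 - x2 - x3
precision = torch.tensor([[ 2., -1.,  0.],
                          [-1.,  2., -1.],
                          [ 0., -1.,  2.]])

covariance = torch.inverse(precision)
print(covariance)  # every entry is non-zero: x1 and x3 are marginally correlated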

Either way I think you’ll need to implement a custom distribution or custom parameters. Here is one way that is statistically efficient but wastes a bit of computation:

import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro.infer.autoguide import AutoContinuous


class AutoSparseMVN(AutoContinuous):
    def __init__(self, model, mask, **kwargs):
        # mask is a symmetric 0/1 matrix giving the sparsity pattern of the precision matrix
        assert (mask == mask.transpose(-1, -2)).all()
        self.mask = mask
        super().__init__(model, **kwargs)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        self.loc = nn.Parameter(self._init_loc())
        self.half_precision = nn.Parameter(torch.eye(
            self.latent_dim, dtype=self.loc.dtype, device=self.loc.device))

    def get_posterior(self, *args, **kwargs):
        # symmetrize the unconstrained parameter, then zero out masked-off entries
        precision = self.half_precision + self.half_precision.transpose(-1, -2)
        precision = precision * self.mask
        return dist.MultivariateNormal(self.loc, precision_matrix=precision)
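For a nearest-neighbour pattern the mask would be tridiagonal. A rough usage sketch (the latent dimension, optimizer settings and model are placeholders for your own):

import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

D = 10  # must match the flattened latent dimension of your model
idx = torch.arange(D)
mask = ((idx.unsqueeze(-1) - idx.unsqueeze(-2)).abs() <= 1).float()  # 1 where |i - j| <= 1

guide = AutoSparseMVN(model, mask=mask)  # `model` is your Pyro model
svi = SVI(model, guide, Adam({"lr": 1e-2}), Trace_ELBO())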

@fritzo Thanks for the quick response. A couple of questions arise when considering your answer.

Firstly, can you elaborate on the non-physicality of a sparse covariance matrix? Is this a statistical property of the covariance matrix, or do you mean that in real-world problems the covariance matrix is in general never sparse?

Secondly, in your example guide you never call pyro.param, which I think is just for brevity of the example, and the params to train are the loc and the half_precision matrix?

Lastly, I wonder what the role of the mask argument of the model is.

Right now I'm trying to implement the following autoguide. The guide will generate a trainable pyro.param of length l = (bands+1)*N_latent - sum(1, ..., bands) (the stored upper-triangular band); these entries will end up in the symmetric banded matrix (tridiagonal for n=1), which should be used as the covariance matrix (or as the precision matrix if the former is not possible) when the guide's get_posterior is called. Quick prototype of the guide:

import numpy as np
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoContinuous


def generate_symm_banded_tensor(values, tensor_size=20, bands=5, wrap=False):
    '''
    Generate an n-banded symmetric matrix with param `values` as entries.
    '''
    if wrap:
        indices = torch.LongTensor([
            [i, k % tensor_size] for i in range(tensor_size) for k in range(i, i + bands + 1)])
    else:
        indices = torch.LongTensor([[i, k] for i in range(tensor_size)
            for k in range(i, i + bands + 1) if (k >= 0 and k < tensor_size)])
    if not (isinstance(values, torch.Tensor) and values.shape[0] == indices.shape[0]):
        raise Exception('Provide a tensor of length l=(bands+1)*N-sum(1..bands)')
    # scatter the banded values into the upper triangle, then form U^T U so the
    # result is symmetric, banded and positive semi-definite
    upper_triangular = torch.sparse.FloatTensor(
        indices.t(), values, torch.Size([tensor_size, tensor_size])).to_dense()
    complete = upper_triangular.T.mm(upper_triangular)
    return complete
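A quick sanity check of the helper with toy numbers: for tensor_size=4 and bands=1 the formula gives (1+1)*4 - 1 = 7 stored entries, and the result comes out symmetric and tridiagonal.

values = torch.ones(7)  # (bands+1)*N - sum(1..bands) = 2*4 - 1 = 7
cov = generate_symm_banded_tensor(values, tensor_size=4, bands=1)
print(cov)  # symmetric, tridiagonal, positive semi-definite by construction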

class NearestNeighbourCorrelationNormalGuide(AutoContinuous):

    def __init__(self, model, prefix="auto", n=1):
        self.bands = n
        super().__init__(model, prefix=prefix)

    def get_posterior(self, *args, **kwargs):
        """
        Returns a MultivariateNormal posterior distribution.
        """
        loc = pyro.param("{}_loc".format(self.prefix), self._init_loc)
        # one positive value per stored band entry: (bands+1)*N - sum(1..bands)
        n_bands_elements = pyro.param("{}_sparse_cov_mat".format(self.prefix),
            lambda: loc.new_ones(
                (self.bands + 1) * self.latent_dim - np.sum(np.arange(1, self.bands + 1))),
            constraint=constraints.positive
        )
        dense_cov_mat = generate_symm_banded_tensor(
            n_bands_elements,
            tensor_size=self.latent_dim,
            bands=self.bands)
        return dist.MultivariateNormal(loc, covariance_matrix=dense_cov_mat)

    def _loc_scale(self, *args, **kwargs):
        loc = pyro.param("{}_loc".format(self.prefix))
        # marginal scales are the square roots of the covariance diagonal
        dense_cov_mat = generate_symm_banded_tensor(
            pyro.param("{}_sparse_cov_mat".format(self.prefix)),
            tensor_size=self.latent_dim,
            bands=self.bands)
        scale = dense_cov_mat.diag().sqrt()
        return loc, scale

This seems to work, but I'm sure the way I create the covariance matrix every time get_posterior is called is really inefficient. I would like to hear any suggestions you have on this part, as well as answers to the questions I posted above.

I would like to thank you and the other active members for your involvement on this forum!

Abel

can you elaborate on the non-physicality of a sparse covariance matrix?

Very generally, our prior on systems is that they decompose into components that interact through a sparse dependency graph, in the sense of probabilistic graphical models. In the toy case of Gaussian components and linear dependencies, a sparse dependency graph corresponds to sparse structure in the Fisher information = precision matrix, and does not lead to sparse structure in the covariance matrix. Thus in real-world problems, the precision matrix is often sparse, but the covariance matrix is seldom sparse (unless it is completely block diagonal, in which case the precision matrix shares that block-diagonal sparsity and often has even more sparsity).

This is a philosophical and methodological position, and I always recommend sparse precision matrices in practice. But it's cheap to test both versions. If you do test both, I'd be interested to hear whether sparse precision or sparse covariance leads to a better fit :smile:

in your example guide you never call pyro.param

Our new AutoGuide classes all inherit from PyroModule, and therefore call pyro.param automatically on nn.Parameter access. Alternatively your pyro.param("{}_loc".format(self.prefix), ...) code will work, but it is more difficult to serve those models using torch.jit.script.
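A minimal sketch of that behaviour (the module and parameter names are made up for illustration):

import torch
import torch.nn as nn
from pyro.nn import PyroModule

class Piece(PyroModule):  # hypothetical module, just to illustrate the mechanism
    def __init__(self):
        super().__init__()
        # accessing this nn.Parameter attribute is routed through pyro.param,
        # so it is registered in the param store without an explicit call
        self.loc = nn.Parameter(torch.zeros(3))

    def forward(self):
        return self.loc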

what the role of the mask argument of the model is.

I added the mask argument to allow you to specify the sparsity pattern of the precision matrix. Without that you’d be learning a full precision matrix.

Guide looks good on first reading.