@fritzo Thanks for the quick response. A couple of questions arise when considering your answer.
Firstly, can you elaborate on the non-physicality of a sparse covariance matrix? Is this a statistical property of the covariance matrix, or do you mean that in real-world problems the covariance matrix is generally never sparse?
Secondly, in your example guide you never call pyro.param. I assume that is just for brevity, and that the params to train are the locs and the half_prec matrix. Is that right?
Lastly, I wonder what the role of the mask argument of the model is.
Right now I'm trying to implement the following autoguide. The guide will generate a trainable pyro.param of length l = (bands+1)*N_latent - sum(1..bands). These entries will end up in the symmetric banded matrix (tridiagonal for bands=1), which should be used as the covariance matrix (or as the precision matrix if the former is not possible) when the guide's get_posterior is called. Quick prototype of the guide:
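import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer.autoguide import AutoContinuous


def generate_symm_banded_tensor(values, tensor_size=20, bands=5, wrap=False):
    '''
    Generate a symmetric banded matrix U^T U, where U is an upper-banded
    matrix whose nonzero entries are taken from `values`.
    '''
    if wrap:
        # Wrap-around bands: column indices are taken modulo tensor_size.
        indices = torch.LongTensor([
            [i, k % tensor_size] for i in range(tensor_size) for k in range(i, i + bands + 1)])
    else:
        # Keep only the in-range entries on the diagonal and the superdiagonals.
        indices = torch.LongTensor([
            [i, k] for i in range(tensor_size)
            for k in range(i, i + bands + 1) if k < tensor_size])
    if not (isinstance(values, torch.Tensor) and values.shape[0] == indices.shape[0]):
        raise ValueError('Provide a tensor of length {}'.format(indices.shape[0]))
    # Place the values at the band positions and convert to a dense matrix.
    upper_triangular = torch.sparse_coo_tensor(
        indices.t(), values, (tensor_size, tensor_size)).to_dense()
    # U^T U is symmetric and positive semi-definite.
    complete = upper_triangular.T.mm(upper_triangular)
    return complete

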
class NearestNeighbourCorrelationNormalGuide(AutoContinuous):
    def __init__(self, model, prefix="auto", n=1):
        self.bands = n
        super().__init__(model, prefix=prefix)

    def get_posterior(self, *args, **kwargs):
        """
        Returns a MultivariateNormal posterior distribution.
        """
        loc = pyro.param("{}_loc".format(self.prefix), self._init_loc)
        # Number of entries on the diagonal and the `bands` superdiagonals.
        n_elements = (self.bands + 1) * self.latent_dim - self.bands * (self.bands + 1) // 2
        n_bands_elements = pyro.param(
            "{}_sparse_cov_mat".format(self.prefix),
            lambda: loc.new_ones(n_elements),
            constraint=constraints.positive)
        print(n_bands_elements.shape)
        dense_cov_mat = generate_symm_banded_tensor(
            n_bands_elements,
            tensor_size=self.latent_dim,
            bands=self.bands)
        print(dense_cov_mat, dense_cov_mat.dtype)
        return dist.MultivariateNormal(loc, covariance_matrix=dense_cov_mat)

    def _loc_scale(self, *args, **kwargs):
        loc = pyro.param("{}_loc".format(self.prefix))
        # Use the same parameter name as in get_posterior.
        n_bands_elements = pyro.param("{}_sparse_cov_mat".format(self.prefix))
        cov = generate_symm_banded_tensor(
            n_bands_elements, tensor_size=self.latent_dim, bands=self.bands)
        # Marginal standard deviations: square roots of the covariance diagonal.
        scale = cov.diag().sqrt()
        return loc, scale

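As a check on the parameter length used in get_posterior, (bands+1)*N - sum(1..bands), here is a small sketch that enumerates the non-wrapping band indices and compares the count with the closed form. The helper names `band_indices` and `band_param_length` are only for illustration and are not part of the prototype or of Pyro.

```python
def band_indices(tensor_size, bands):
    """Enumerate (row, col) pairs on the diagonal and the `bands` superdiagonals."""
    return [(i, k)
            for i in range(tensor_size)
            for k in range(i, i + bands + 1)
            if k < tensor_size]


def band_param_length(tensor_size, bands):
    """Closed form: (bands + 1) * N - (1 + 2 + ... + bands), valid for bands < N."""
    return (bands + 1) * tensor_size - bands * (bands + 1) // 2


# The closed form agrees with the enumeration for a few small cases.
for n, b in [(5, 1), (20, 5), (7, 3)]:
    assert len(band_indices(n, b)) == band_param_length(n, b)

print(band_param_length(5, 1))  # 9: five diagonal entries plus four superdiagonal entries
```

Note that with wrap=True every row contributes bands+1 entries, so the length would be (bands+1)*N instead.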
This seems to work, but I'm sure the way I create the covariance matrix every time get_posterior is called is really inefficient. I would like to hear whether you have any suggestions on this part, as well as answers to the questions I posted above.
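For concreteness, a minimal sketch of one possible way to avoid rebuilding the index tensor on every call: compute the band indices once and only scatter the values on each call. `BandedCovBuilder` is a hypothetical name, not part of Pyro, and the sketch covers only the non-wrapping case.

```python
import torch


class BandedCovBuilder:
    """Hypothetical helper: caches band indices and builds U^T U on demand."""

    def __init__(self, tensor_size, bands):
        self.tensor_size = tensor_size
        rows, cols = zip(*[(i, k)
                           for i in range(tensor_size)
                           for k in range(i, min(i + bands + 1, tensor_size))])
        # Built once and reused on every call.
        self.rows = torch.tensor(rows)
        self.cols = torch.tensor(cols)

    def __call__(self, values):
        # Scatter the values into a dense upper-banded matrix U.
        upper = values.new_zeros(self.tensor_size, self.tensor_size)
        upper = upper.index_put((self.rows, self.cols), values)
        # U^T U is symmetric; it is positive definite when diag(U) is nonzero.
        return upper.T @ upper


builder = BandedCovBuilder(tensor_size=5, bands=1)
values = torch.ones(9, requires_grad=True)
cov = builder(values)
assert torch.allclose(cov, cov.T)
torch.linalg.cholesky(cov)   # raises if cov is not positive definite
cov.sum().backward()         # gradients flow back to `values`
```

Since U.T is lower triangular with a positive diagonal in this setting, passing it as scale_tril to dist.MultivariateNormal instead of covariance_matrix might also save a Cholesky factorization per call. That part is untested here.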
I would like to thank you and the other active members for your involvement on this forum!
Abel