Is there a constraint to enforce orthogonality?

I’m trying to implement probabilistic PCA. I’m getting the components, but they’re not orthogonal to each other and they don’t have unit length. I’m wondering how to enforce those constraints in a model. For reference, here’s my simple code:

import torch
import pyro
import matplotlib.pyplot as plt
import numpy as np
import pyro.infer
import pyro.optim
import pyro.distributions as dist
from torch.distributions import constraints
from sklearn import datasets
import pyro.poutine as poutine
from sklearn import decomposition
from sklearn import preprocessing


pyro.enable_validation(True)  # <---- This is always a good idea!

pyro.set_rng_seed(101)
d, D = 2, 4  # small dimension d, large dimension D.
iris = datasets.load_iris()
y = iris.target
X = iris.data

scaler = preprocessing.StandardScaler(with_std=False)
X = scaler.fit_transform(X)
X = torch.tensor(X, dtype=torch.float32)


def ppca(data):
    # Factor loading matrix (D x d) and observation noise scale.
    A = pyro.param("A", torch.zeros((D, d)))
    sig = pyro.param("sig", torch.ones(1), constraint=constraints.positive)
    # mu = pyro.param("mu", torch.zeros(D))
    for i in pyro.plate("data", len(data)):
        # Standard normal latent code, mapped into data space through A.
        z = pyro.sample("latent_{}".format(i), dist.Normal(torch.zeros(d), 1.0).to_event(1))
        pyro.sample("observed_{}".format(i), dist.Normal(A @ z, sig).to_event(1), obs=data[i])

#%%


def guide(data):
    A = pyro.param("A", torch.zeros((D, d)))
    # mu_ = pyro.param("mu_", torch.zeros(D))
    for i in pyro.plate("data", len(data)):
        # Approximate posterior: project each observation back into latent space.
        pyro.sample("latent_{}".format(i), dist.Normal(A.T @ data[i], 1.0).to_event(1))


pyro.clear_param_store()
svi = pyro.infer.SVI(model=ppca,
                     guide=guide,
                     optim=pyro.optim.Adam({"lr": 0.001}),
                     loss=pyro.infer.Trace_ELBO())

losses = []
num_steps = 2500
for t in range(num_steps):
    loss = svi.step(X)
    losses.append(loss)
    if t % 100 == 0:
        print(f'step = {t}, loss = {loss}', )

plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss")

Not as such, but you can use a HouseholderFlow to get orthonormality:

import torch
from pyro.distributions.transforms import HouseholderFlow

dim = 3
n_columns = 2
householder = HouseholderFlow(input_dim=dim, count_transforms=dim)
ortho_columns = householder(torch.eye(dim))[:n_columns]

print("dot product:", torch.dot(ortho_columns[0], ortho_columns[1]).item())
print("ortho_columns\n", ortho_columns.data.numpy())

Note that you would need to run householder(...) in every iteration of your algorithm if you want gradients to flow correctly.
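To make that concrete, here is a rough sketch of how the flow could replace the free-form A in the question's model. The ppca_ortho name and the per-component scales parameter are my own additions (orthonormal columns all have unit length, so some extra scale is needed to carry each component's variance), so treat this as one possible wiring rather than a canonical recipe:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.distributions.transforms import HouseholderFlow

D, d = 4, 2

# Created once, outside the model, so the same transform (and its parameters)
# is reused at every SVI step.
householder = HouseholderFlow(input_dim=D, count_transforms=D)

def ppca_ortho(data):
    # Register the flow with Pyro so SVI optimizes its parameters.
    pyro.module("householder", householder)
    sig = pyro.param("sig", torch.ones(1), constraint=constraints.positive)
    # Hypothetical extra parameter: orthonormal directions have length 1,
    # so a positive per-component scale carries each component's magnitude.
    scales = pyro.param("scales", torch.ones(d), constraint=constraints.positive)
    # Re-apply the transform every step so gradients flow through it.
    W = householder(torch.eye(D))[:d].T      # (D, d) with orthonormal columns
    A = W * scales                           # scale each column
    for i in pyro.plate("data", len(data)):
        z = pyro.sample("latent_{}".format(i), dist.Normal(torch.zeros(d), 1.0).to_event(1))
        pyro.sample("observed_{}".format(i), dist.Normal(A @ z, sig).to_event(1), obs=data[i])

The guide would need the same treatment (e.g. building its projection from the same transform) so that model and guide agree on what the components are.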

You might consider using a non-orthogonal parametrization during training and only pulling out orthogonal components after training. For example, you can use a LowRankMultivariateNormal distribution with only a constraints.positive constraint on the cov_diag term; after training you can pull the components out of torch.svd(d.covariance_matrix). This is similar to the approach of AutoLowRankMultivariateNormal.
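A minimal sketch of that route, assuming the latent z is collapsed out analytically (which is exactly PPCA's marginal likelihood, N(0, W W^T + diag)); the names ppca_lowrank and cov_factor, and the SVD post-processing, are my own illustration:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

D, d = 4, 2

def ppca_lowrank(data):
    # Free-form low-rank factor; no orthogonality is enforced during training.
    cov_factor = pyro.param("cov_factor", 0.01 * torch.randn(D, d))
    cov_diag = pyro.param("cov_diag", torch.ones(D), constraint=constraints.positive)
    with pyro.plate("data", len(data)):
        pyro.sample("observed",
                    dist.LowRankMultivariateNormal(torch.zeros(D), cov_factor, cov_diag),
                    obs=data)

# After training, recover orthonormal unit-length components from the fitted covariance.
fitted = dist.LowRankMultivariateNormal(
    torch.zeros(D),
    pyro.param("cov_factor").detach(),
    pyro.param("cov_diag").detach(),
)
U, S, V = torch.svd(fitted.covariance_matrix)
components = U[:, :d]   # orthonormal columns, ordered by singular value

Since this parametrization has no local latent variables left, an empty guide (lambda data: None) is enough for SVI.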