Is there a constraint to enforce orthogonality?

I’m trying to implement probabilistic PCA. I’m getting the components, but they’re not orthogonal to each other and they don’t have unit length. I’m wondering how we can enforce those constraints in a model. For reference, here’s my simple code:

import torch
import pyro
import matplotlib.pyplot as plt
import numpy as np
import pyro.infer
import pyro.optim
import pyro.distributions as dist
from torch.distributions import constraints
from sklearn import datasets
import pyro.poutine as poutine
from sklearn import decomposition
from sklearn import preprocessing

pyro.enable_validation(True)  # <---- This is always a good idea!

d, D = 2, 4  # small dimension d, large dimension D.

# Load the iris data set. `load_iris()` returns a Bunch whose `.data`
# attribute holds the feature matrix and whose `.target` attribute holds
# the class labels.
iris = datasets.load_iris()
y = iris.target  # class labels; not used by the model or guide below
X = iris.data    # feature matrix; its column count must equal D

# Subtract the per-feature mean only (with_std=False leaves the scale
# untouched), then convert to a float32 tensor for Pyro/PyTorch.
scaler = preprocessing.StandardScaler(with_std=False)
X = scaler.fit_transform(X)
X = torch.tensor(X, dtype=torch.float32)

def ppca(data):
    """Generative model for probabilistic PCA.

    Each row of ``data`` is modelled as ``Normal(A @ z, sig)``, where ``z``
    is a ``d``-dimensional latent drawn from a standard normal prior.

    Learnable parameters registered in the Pyro param store:
        * ``"A"``   -- ``(D, d)`` loading matrix, initialised to zeros.
        * ``"sig"`` -- observation noise scale, constrained to be positive.

    Args:
        data: indexable collection of observations; ``len(data)`` sets the
            size of the ``"data"`` plate and ``data[i]`` is the i-th
            observation.
    """
    loadings = pyro.param("A", torch.zeros((D, d)))
    noise_scale = pyro.param("sig", torch.ones(1), constraint=constraints.positive)
    # Optional mean parameter, currently disabled:
    # mu = pyro.param("mu", torch.zeros(D))

    # The prior location is the same for every data point, so build it once.
    prior_loc = torch.zeros(d)

    # Sequential plate: one latent site and one observed site per data point.
    for idx in pyro.plate("data", len(data)):
        prior = dist.Normal(prior_loc, 1.0).to_event(1)
        latent = pyro.sample("latent_{}".format(idx), prior)

        likelihood = dist.Normal(loadings @ latent, noise_scale).to_event(1)
        pyro.sample("observed_{}".format(idx), likelihood, obs=data[idx])


def guide(data):
    """Variational guide paired with ``ppca``.

    For every observation ``data[i]`` the latent site ``"latent_<i>"`` is
    drawn from ``Normal(A.T @ data[i], 1.0)``. The parameter ``"A"`` has the
    same name as in the model, so both read the same entry in the Pyro
    param store.

    Args:
        data: indexable collection of observations; ``len(data)`` sets the
            size of the ``"data"`` plate.
    """
    loadings = pyro.param("A", torch.zeros((D, d)))
    # Optional mean parameter, currently disabled:
    # mu_ = pyro.param("mu_", torch.zeros(D))

    # Project each observation back into latent space with the transpose
    # of the loading matrix.
    back_projection = loadings.T
    for idx in pyro.plate("data", len(data)):
        posterior = dist.Normal(back_projection @ data[idx], 1.0).to_event(1)
        pyro.sample("latent_{}".format(idx), posterior)

# Stochastic variational inference: fit the parameters of `ppca` and `guide`
# by optimising the ELBO with Adam.
svi = pyro.infer.SVI(model=ppca,
                     guide=guide,
                     optim=pyro.optim.Adam({"lr": 0.001}),
                     loss=pyro.infer.Trace_ELBO())

losses = []  # loss value returned by every SVI step
num_steps = 2500
for t in range(num_steps):
    loss = svi.step(X)
    losses.append(loss)
    # Report progress every 100 steps.
    if t % 100 == 0:
        print(f'step = {t}, loss = {loss}')


Not as such; however, you can use a HouseholderFlow to get orthonormality:

import torch
from pyro.distributions.transforms import HouseholderFlow

dim = 3
n_columns = 2

# NOTE(review): newer Pyro releases appear to expose this transform under the
# name `Householder`; confirm against the installed Pyro version.
householder = HouseholderFlow(input_dim=dim, count_transforms=dim)

# Push the identity matrix through the flow and keep the first `n_columns`
# rows of the result as the orthonormal vectors.
ortho_columns = householder(torch.eye(dim))[:n_columns]

# Sanity check: the dot product of two distinct vectors should be close to 0.
print("dot product:", torch.dot(ortho_columns[0], ortho_columns[1]).item())

Note that you would need to run householder(...) in every iteration of your algorithm if you want gradients to flow correctly.