Orthogonal manifold¶

📚 This page contains original research. To cite the Modula docs, here’s some BibTeX:

@misc{modula-docs,
   author  = {Jeremy Bernstein},
   title   = {The Modula Docs},
   url     = {https://docs.modula.systems/},
   year    = 2025
}

On this page, we will work out an algorithm for performing gradient descent on the manifold of orthogonal matrices while taking steps that are steepest under the spectral norm. The algorithm will solve for the matrix of unit spectral norm that maximizes the linearized improvement in loss while lying tangent to the manifold. The “retraction map”—which sends the update from the tangent space back to the manifold—involves a few extra matrix multiplications.

Steepest descent on the orthogonal manifold¶

Consider a square weight matrix \(W\in\mathbb{R}^{n \times n}\) that is orthogonal, meaning that \(W^\top W = I_n\). Suppose that the “gradient matrix” \(G\in\mathbb{R}^{n\times n}\) is the derivative of some loss function evaluated at \(W\). Given step size \(\eta > 0\), we claim that the following weight update is steepest under the spectral norm while staying on the orthogonal manifold. First, we take the matrix sign of the skew part of \(W^\top G\):

\[X = \operatorname{msign}[W^\top G - G^\top W],\]

where the matrix sign \(\operatorname{msign}\) of a matrix \(M\) returns the matrix with the same singular vectors as \(M\) but with all positive singular values set to one. We then make the update:

\[W \mapsto W \cdot (I_n - \eta X) \cdot \left(I_n - X^\top X + \frac{X^\top X}{\sqrt{1+\eta^2}}\right).\]

The final bracket constitutes the “retraction map”, which snaps the updated weights back to the manifold. Curiously, the update can be written in a purely multiplicative form.
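
As a quick sanity check of this claim, here is a minimal sketch that applies the update once and measures how far the result drifts from the manifold. The SVD-based helper msign_svd, the random test matrices, and the step size are illustrative choices of ours (a Newton-Schulz-based implementation appears in the Python code section below), and the SVD shortcut assumes the skew matrix is full rank, which holds generically when \(n\) is even.

import jax
import jax.numpy as jnp

def msign_svd(M):
    # Matrix sign via an SVD: keep the singular vectors, send the singular values to one.
    # This matches msign exactly when M is full rank (generic here since n is even).
    U, _, Vt = jnp.linalg.svd(M)
    return U @ Vt

n, eta = 4, 0.1
W, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(0), (n, n)))  # random orthogonal W
G = jax.random.normal(jax.random.PRNGKey(1), (n, n))                    # stand-in gradient matrix

X = msign_svd(W.T @ G - G.T @ W)
I = jnp.eye(n)
W_new = W @ (I - eta * X) @ (I - X.T @ X + X.T @ X / jnp.sqrt(1 + eta**2))

print(jnp.max(jnp.abs(W_new.T @ W_new - I)))  # should sit at floating-point roundoff level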

Non-Riemannian manifold methods¶

One reason this algorithm is interesting is that it is an example of a manifold optimization algorithm that is non-Riemannian. A Riemannian manifold is a manifold equipped with a structure called a Riemannian metric, which is an inner product defined at each point on the manifold. The inner product provides a way to measure distances and to construct geometry-aware optimization algorithms. There has been a lot of research into Riemannian optimization methods, including a number of works in a machine learning context.

However, there has seemingly been much less research into optimization algorithms on manifolds equipped with non-Riemannian structures. For instance, a matrix manifold equipped with the spectral norm at every point is non-Riemannian since the spectral norm does not emerge from an inner product. But we believe these kinds of non-Riemannian geometries are very important in deep learning.

The structure of the tangent space¶

We would like to make a weight update so that the updated weights stay on the orthogonal manifold. First we need to figure out the structure of the “tangent space” at a point on the manifold. Roughly speaking, the tangent space is the set of possible velocities a particle could have as it passes through that particular point. So we need to consider all curves passing through the point on the manifold.

If we consider a curve \(W(t)\) on the manifold parameterized by time \(t \in \mathbb{R}\), then this curve must satisfy \(W(t)^\top W(t) = I_n\). Differentiating with respect to \(t\), we find that the velocity must satisfy:

\[\frac{\partial W(t)}{\partial t}^\top W(t) + W(t)^\top \frac{\partial W(t)}{\partial t} = 0.\]

So to be in the tangent space of a point \(W\) on the manifold, a matrix \(A\) must satisfy \(A^\top W + W^\top A = 0\). Conversely, if a matrix \(A\) satisfies \(A^\top W + W^\top A=0\), then it is the velocity of a curve on the manifold that passes through \(W\), as evidenced by the curve \(W(t) = W \exp(tW^\top A)\). Therefore, the tangent space at \(W\) is completely characterized by the set:

\[\{A\in \mathbb{R}^{n\times n}:A^\top W + W^\top A = 0\}.\]
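
The converse direction can be checked directly: the condition \(A^\top W + W^\top A = 0\) says that \(W^\top A\) is skew-symmetric, so

\[W(t)^\top W(t) = \exp(t W^\top A)^\top \, W^\top W \, \exp(t W^\top A) = \exp(-t W^\top A) \exp(t W^\top A) = I_n,\]

and differentiating \(W(t) = W \exp(t W^\top A)\) at \(t = 0\) gives velocity \(W W^\top A = A\), using \(W W^\top = I_n\) for square orthogonal \(W\).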

Finally, if we use the orthogonal matrix \(W\) to make the change of variables \(A = W X\), then we see that \(A\) belongs to the tangent space at \(W\) if and only if \(X\) is skew-symmetric: \(X^\top + X = 0\). So the tangent space to the orthogonal manifold can be parameterized by skew-symmetric matrices.
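
Spelling out this change of variables: substituting \(A = WX\) into the tangency condition and using \(W^\top W = I_n\) gives

\[A^\top W + W^\top A = X^\top W^\top W + W^\top W X = X^\top + X,\]

which vanishes exactly when \(X\) is skew-symmetric.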

Steepest direction in the tangent space¶

We will solve for the matrix \(A\) that belongs to the tangent space to the orthogonal manifold at matrix \(W\) and maximizes the linearized improvement in loss \(\operatorname{trace}(G^\top A)\) under the constraint that \(A\) has unit spectral norm. Formally, we wish to solve:

\[\operatorname{arg max}_{A\in \mathbb{R}^{n\times n}: \|A\|_*\leq 1 \text{ and } A^\top W + W^\top A = 0}\; \operatorname{trace}(G^\top A).\]

To simplify, we make the change of variables \(A = W X\) so that we now only need to maximize over skew-symmetric matrices \(X\) of unit spectral norm:

\[\operatorname{arg max}_{X\in \mathbb{R}^{n\times n}:\|X\|_*\leq 1 \text{ and } X^\top + X= 0}\; \operatorname{trace}([W^\top G]^\top X).\]
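
Two identities justify this rewriting: using \((W^\top G)^\top = G^\top W\) and the fact that multiplying by the orthogonal matrix \(W\) does not change singular values,

\[\operatorname{trace}(G^\top A) = \operatorname{trace}(G^\top W X) = \operatorname{trace}([W^\top G]^\top X) \quad \text{and} \quad \|A\|_* = \|W X\|_* = \|X\|_*.\]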

Next, we decompose \(W^\top G = \frac{1}{2}[W^\top G + G^\top W] + \frac{1}{2}[W^\top G - G^\top W]\) into its symmetric and skew-symmetric components and realize that, because \(X\) is skew-symmetric, the contribution to the trace from the symmetric part of \(W^\top G\) vanishes. So the problem becomes:

\[\operatorname{arg max}_{X\in \mathbb{R}^{n\times n}:\|X\|_*\leq 1 \text{ and } X^\top + X= 0}\; \operatorname{trace}\left(\left[\frac{W^\top G - G^\top W}{2}\right]^\top X\right).\]
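
The symmetric part drops out because the trace pairing of a symmetric matrix \(S\) with a skew-symmetric matrix \(X\) is always zero:

\[\operatorname{trace}(S^\top X) = \operatorname{trace}\big((S^\top X)^\top\big) = \operatorname{trace}(X^\top S) = -\operatorname{trace}(X S) = -\operatorname{trace}(S X) = -\operatorname{trace}(S^\top X),\]

forcing \(\operatorname{trace}(S^\top X) = 0\).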

If we simply ignore the skew-symmetric constraint, the solution for \(X\) is given by \(X = \operatorname{msign}[W^\top G - G^\top W]\). But this solution for \(X\) actually satisfies the skew-symmetric constraint! This is because the matrix sign function preserves skew-symmetry. An easy way to see this is that \(\operatorname{msign}[W^\top G - G^\top W]\) can be computed by running an odd polynomial iteration (see Newton-Schulz) on \(W^\top G - G^\top W\), and odd polynomials preserve skew-symmetry. [1]
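
Concretely, if \(M\) is skew-symmetric and \(p\) is an odd polynomial with real coefficients, then

\[p(M)^\top = p(M^\top) = p(-M) = -p(M),\]

so every iterate of the polynomial iteration stays skew-symmetric, and hence so does its limit \(\operatorname{msign}(M)\).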

Undoing the change of variables, our tangent vector is given by \(A = W \cdot \operatorname{msign}[W^\top G - G^\top W]\).

Finding the retraction map¶

The previous section suggests making the weight update \(W \mapsto W - \eta W X = W (I_n - \eta X)\). This update takes a step in the tangent space, which diverges slightly from the orthogonal manifold for finite step sizes. A relatively expensive way to fix this issue is to just apply the matrix sign function, i.e. \(W \mapsto \operatorname{msign}[W (I_n - \eta X)]\), to project the weights back to the manifold. But we will show in this section that there is actually a shortcut.

As a warmup, let’s first consider the case that \(W^\top G - G^\top W\) is full rank. Then \(X\) is an orthogonal matrix and \([W (I_n - \eta X)]^\top [W (I_n - \eta X)] = (1 + \eta^2) I_n\). Therefore, in this case, we can project back to the manifold simply by dividing the updated weights through by the scalar \(\sqrt{1+\eta^2}\).
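
To verify this warmup claim, expand the product using \(W^\top W = I_n\), the skew-symmetry \(X + X^\top = 0\), and the fact that \(X^\top X = I_n\) when \(X\) is orthogonal:

\[[W (I_n - \eta X)]^\top [W (I_n - \eta X)] = I_n - \eta (X + X^\top) + \eta^2 X^\top X = I_n + \eta^2 X^\top X = (1+\eta^2) I_n.\]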

In the general case where \(W^\top G - G^\top W\) and therefore \(X = \operatorname{msign}[W^\top G - G^\top W]\) may not be full rank, let us search for a matrix \(C\) such that \(W \cdot (I_n - \eta X) \cdot C\) is orthogonal. Checking the orthogonality condition \((W \cdot (I_n - \eta X) \cdot C)^\top (W \cdot (I_n - \eta X) \cdot C)=I_n\) reveals that we need to find a matrix \(C\) such that:

\[C^\top (I_n + \eta^2 X^\top X) C = I_n.\]

The trick is to recognize \(X^\top X\) as the orthogonal projector onto the row space of \(X\). The matrix \(I_n + \eta^2 X^\top X\) leaves vectors in the null space of \(X\) unchanged but scales vectors in the row space of \(X\) by a factor of \(1+\eta^2\). It therefore suffices to choose a symmetric matrix \(C\) that undoes this scaling across its two applications in the orthogonality condition: once through \(C^\top\) and once through \(C\). Noting that \(I_n - X^\top X\) projects onto the null space of \(X\), the following choice of \(C\) is what we need:

\[C = C^\top = I_n - X^\top X + \frac{X^\top X}{\sqrt{1+\eta^2}}.\]
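
To double-check this choice, write \(P = X^\top X\), which is a projector (\(P^2 = P\)) and commutes with \(C\). On the null space of \(X\), both \(C\) and \(I_n + \eta^2 P\) act as the identity; on the row space, \(C\) acts as the scalar \(\frac{1}{\sqrt{1+\eta^2}}\) while \(I_n + \eta^2 P\) acts as \(1+\eta^2\). The scalars multiply to one in both cases, so

\[C^\top (I_n + \eta^2 X^\top X) \, C = (I_n - P) + \frac{1}{\sqrt{1+\eta^2}} (1+\eta^2) \frac{1}{\sqrt{1+\eta^2}} \, P = I_n.\]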

Python code¶

Here is a basic JAX implementation of the algorithm:

import jax.numpy as jnp
import math

def orthogonalize(M, steps=10):
    # Approximate msign(M) with an odd polynomial (Newton-Schulz style) iteration.
    a, b, c = 3, -16/5, 6/5
    transpose = M.shape[1] > M.shape[0]
    if transpose:
        M = M.T
    M = M / jnp.linalg.norm(M)  # Frobenius normalization ensures the spectral norm is at most one
    for _ in range(steps):
        A = M.T @ M
        I = jnp.eye(A.shape[0])
        M = M @ (a * I + b * A + c * A @ A)
    if transpose:
        M = M.T
    return M

def update(W, G, eta, NS_steps=10):
    # One step of spectral-norm steepest descent on the orthogonal manifold.
    I = jnp.eye(W.shape[0])
    X = orthogonalize(W.T @ G - G.T @ W, NS_steps)  # skew-symmetric tangent direction
    retraction_factor = I - (1 - math.sqrt(1 / (1 + eta**2))) * (X.T @ X)
    return W @ (I - eta * X) @ retraction_factor
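
And here is a small toy usage sketch, assuming the two functions above are in scope. The target matrix, the loss, the step size, and the iteration count are illustrative choices of ours: we descend the loss \(L(W) = -\operatorname{trace}(T^\top W)\), whose gradient is \(G = -T\) and whose minimum over orthogonal \(W\) is \(-n\), attained at \(W = T\).

import jax

# Toy problem: steer W toward a fixed orthogonal target T by descending
# L(W) = -trace(T^T W) on the manifold; the gradient of this loss is G = -T.
n = 8
T, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(0), (n, n)))
W, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(1), (n, n)))

print("initial loss:", -jnp.trace(T.T @ W))
for _ in range(100):
    W = update(W, -T, eta=0.05)
print("final loss:", -jnp.trace(T.T @ W))                              # should be much closer to -n
print("orthogonality error:", jnp.max(jnp.abs(W.T @ W - jnp.eye(n))))  # stays small, limited by Newton-Schulz accuracy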

Open problem: Extending to the Stiefel Manifold¶

I initially thought that this solution easily extended to the Stiefel manifold (i.e. the set of \(m \times n\) semi-orthogonal matrices). But this turns out not to be the case: the algorithm we derived is generally not optimal if \(W\) is rectangular. To see this, let's consider an \(m \times n\) matrix \(W\) with \(m > n\), and suppose that it belongs to the Stiefel manifold, meaning that \(W^\top W = I_n\). The problem with our derivation is that the change of variables \(A = W X\) no longer parameterizes the full set of \(m \times n\) matrices. Instead, we need to make the change of variables \(A = WX + \overline{W}Y\), where the columns of \(\overline{W}\) are the “missing” columns of \(W\). In other words, the combined matrix \([W \mid \overline{W}]\) is a square orthogonal matrix. For this parameterization, the tangent space to the Stiefel manifold is obtained by requiring that \(X\in\mathbb{R}^{n\times n}\) is skew-symmetric while \(Y\in\mathbb{R}^{(m-n)\times n}\) is completely unconstrained. I do not know how to analytically solve the resulting maximization problem in this parameterization.
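
For concreteness, here is the resulting problem written out in this parameterization (our rewriting of the constraints above, not a known closed-form result). Since \([W \mid \overline{W}]\) is orthogonal, \(\|WX + \overline{W}Y\|_*\) equals the spectral norm of the stacked matrix \(\begin{pmatrix} X \\ Y \end{pmatrix}\), so the norm constraint couples \(X\) and \(Y\):

\[\operatorname{arg max}_{X^\top + X = 0,\; Y \in \mathbb{R}^{(m-n)\times n},\; \left\|\begin{pmatrix} X \\ Y \end{pmatrix}\right\|_* \leq 1}\; \operatorname{trace}\left(\left[\frac{W^\top G - G^\top W}{2}\right]^\top X\right) + \operatorname{trace}\left([\overline{W}^\top G]^\top Y\right).\]

If the constraint instead bounded \(\|X\|_*\) and \(\|Y\|_*\) separately, the \(X\) part would reduce to the square case above and the \(Y\) part to a plain matrix sign; the shared spectral-norm budget appears to be what blocks the earlier argument.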