Stiefel manifold¶
📚 This page contains original research. To cite the Modula docs, here’s some BibTeX:
@misc{modula-docs,
author = {Jeremy Bernstein},
title = {The Modula Docs},
url = {https://docs.modula.systems/},
year = 2025
}
On this page we shall consider a problem that I affectionately refer to as manifold Muon—or, more formally, the problem of steepest descent under the spectral norm on the Stiefel manifold. This problem arises when one is interested in taking the best possible optimization step in a spectral norm geometry (useful for accelerating training) while keeping the size of the weight matrices tightly regulated (potentially helpful for training stability and removing learning rate confounders). This page will generalize the analysis from the square case to the full Stiefel manifold.
I posed manifold Muon as an open problem on the Modula docs earlier this year, and two researchers, Franz Louis Cesista (a.k.a. Leloy) and Jianlin Su, recently proposed solutions. Leloy proposed a heuristic solution via alternating projections, and Jianlin solved the problem by setting up a fixed point iteration. I heard about Leloy’s work and an early version of Jianlin’s approach (which did not yet work), and I managed to solve the problem myself with a slightly different approach based on Lagrangian duality, which I will present in the next section. I also want to acknowledge that Cédric Simal independently proposed studying the dual problem to me and Leloy, after I had worked out the following analysis.
Formulating the problem¶
Let’s set up the problem mathematically. Say we have a matrix-valued optimization variable \(W \in \mathbb{R}^{m \times n}\) where, without loss of generality, we take \(m\geq n\) so that the matrix has at least as many rows as columns. And we have a cost function \(\mathcal{C}:\mathbb{R}^{m \times n}\to\mathbb{R}\) that we would like to minimize. We would also like to constrain the matrix \(W\) to the following set:

\[\mathsf{Stiefel}(m,n) := \{W \in \mathbb{R}^{m\times n} : W^\top W = I_n\}.\]
This set is known as the Stiefel manifold. A matrix \(W\in\mathsf{Stiefel}(m,n)\) for \(m>n\) is known as a semi-orthogonal matrix—since it has too few columns to form a complete orthonormal basis. There are various alternative ways to characterize the Stiefel manifold. For example, it is equivalently the set of \(m \times n\) matrices whose singular values all equal one; that is, matrices with unit spectral norm and unit \(\ell_2 \to \ell_2\) condition number. Suffice it to say, the Stiefel manifold is a very well-behaved class of matrices.
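As a quick numerical illustration (my own sketch, not part of the original text; the sizes are arbitrary, and the QR decomposition is just one convenient way to produce a semi-orthogonal matrix), the snippet below samples a random point on \(\mathsf{Stiefel}(m,n)\) and confirms both characterizations.

import torch

# Sample a random semi-orthogonal matrix by orthonormalizing the columns of a Gaussian matrix
m, n = 5, 3
W = torch.linalg.qr(torch.randn(m, n, dtype=torch.float64))[0]

# Characterization 1: the columns are orthonormal, W^T W = I_n
print(torch.allclose(W.T @ W, torch.eye(n, dtype=torch.float64), atol=1e-12))

# Characterization 2: every singular value equals one, so the spectral norm
# and the l2 -> l2 condition number are both one
print(torch.linalg.svdvals(W))   # expect a vector of (numerically) ones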
We would like to be able to take optimization steps that lie tangent to this manifold. Just as in the square case, we can show that the tangent space to the Stiefel manifold at a semi-orthogonal matrix \(W\in\mathsf{Stiefel}(m,n)\) is given by the following linear subspace of the ambient matrix space \(\mathbb{R}^{m \times n}\):

\[T_W\,\mathsf{Stiefel}(m,n) = \{A \in \mathbb{R}^{m\times n} : W^\top A + A^\top W = 0\}.\]
In the context of Riemannian optimization, there are established means of projecting the gradient to this linear subspace in order to take steps tangent to the Stiefel manifold. But to make life more interesting, we shall be interested in cost functions with a different sort of structure. In particular, suppose our cost \(\mathcal{C}\) is Lipschitz-smooth in the spectral norm, meaning that for some constant \(L>0\) and any update \(\Delta W \in \mathbb{R}^{m\times n}\):

\[\mathcal{C}(W + \Delta W) \leq \mathcal{C}(W) + \langle \nabla \mathcal{C}(W), \Delta W\rangle + \frac{L}{2}\,\|\Delta W\|_\mathrm{spectral}^2,\]
where \(\langle \nabla \mathcal{C}(W), \Delta W\rangle \equiv \operatorname{trace} \nabla \mathcal{C}^\top \Delta W\) is the Frobenius inner product between the derivative of the cost and the weight update, measuring the first-order change in cost. To motivate this smoothness structure, observe that matrices in a neural network act as operators on vectors, and the spectral norm respects this fact—see our anthology for more on this. Spectral norm smoothness suggests taking optimization steps of controlled spectral norm. And since the spectral norm does not emerge from an inner product, spectral norm smoothness takes us outside the realm of Riemannian geometry.
All told, we would like to design a gradient descent algorithm whose updates exploit the spectral norm geometry of the cost function while lying tangent to the Stiefel manifold. Here we focus on the problem of choosing the direction of the step given these constraints, and offload the problem of choosing the magnitude to the learning rate. We formulate the optimal update direction as the matrix \(A\) that solves the following minimization problem:

\[\min_{A \in \mathbb{R}^{m\times n}} \;\langle G, A\rangle \quad \text{subject to} \quad \|A\|_\mathrm{spectral} \leq 1 \quad \text{and} \quad W^\top A + A^\top W = 0. \tag{1}\]
In this expression, \(W\) is the current point on the manifold, \(G := \nabla \mathcal{C}(W)\) is shorthand for the derivative of the cost, and \(A\) is the update direction that we seek. In words, we want to find an update direction that squeezes out the most linear improvement in cost while lying inside the ball of unit spectral norm and also lying tangent to the Stiefel manifold.
Solving manifold Muon via Lagrangian duality¶
Similar to Jianlin’s approach, we introduce a matrix \(\Lambda\in\mathbb{R}^{n\times n}\) of Lagrange multipliers, and define a Lagrangian function \(\mathcal{L}(A, \Lambda)\) that incorporates the tangent space constraint:

\[\mathcal{L}(A, \Lambda) := \langle G, A\rangle + \langle \Lambda,\, W^\top A + A^\top W\rangle = \langle G + W(\Lambda + \Lambda^\top),\, A\rangle,\]
where the second equality follows by applying the cyclic property of the trace and transposing one term. One can check that our original problem (1) is equivalent to the saddle point problem \(\min_{\|A\|_\mathrm{spectral} \leq 1} \max_{\Lambda} \mathcal{L}(A,\Lambda)\) since for any \(A\) that violates the tangent space constraint, the inner maximization with respect to \(\Lambda\) would send the Lagrangian to infinity. By Sion’s minimax theorem, we can swap the order of the \(\min\) and \(\max\) to obtain:

\[\min_{\|A\|_\mathrm{spectral} \leq 1}\; \max_{\Lambda}\; \mathcal{L}(A,\Lambda) \;=\; \max_{\Lambda}\; \min_{\|A\|_\mathrm{spectral} \leq 1}\; \mathcal{L}(A,\Lambda).\]
Following an argument which is now standard in Muon lore, we recognize the optimal value \(A_\mathrm{opt}(\Lambda)\) of the primal variable \(A\) for a given dual variable \(\Lambda\) as:

\[A_\mathrm{opt}(\Lambda) = -\operatorname{msign}\big(G + W(\Lambda + \Lambda^\top)\big),\]
where \(\operatorname{msign}\) is the matrix sign function, defined as the elementwise sign function applied to the singular values of a matrix, or in PyTorch code:
import torch

def msign(X):
    # Reduced SVD: X = U diag(S) V^T
    U, S, V = torch.svd(X)
    # Apply the elementwise sign function to the singular values
    return U @ S.sign().diag() @ V.T
Note that \(\operatorname{msign}\) can be computed efficiently on GPUs without taking an SVD via Newton-Schulz iteration as in the recent Polar Express algorithm.
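For intuition, here is a minimal Newton–Schulz sketch of my own. It uses the classical cubic coefficients (1.5, −0.5) rather than the carefully tuned coefficient schedule of Polar Express, so treat it as illustrative only; the function name, step count, and normalization are my own choices.

def msign_newton_schulz(X, steps=25, eps=1e-7):
    # Work in the tall orientation so the Gram matrix below is small
    transpose = X.shape[0] < X.shape[1]
    Y = X.T if transpose else X
    # Normalize so every singular value lies in (0, sqrt(3)), the basin of the cubic iteration
    Y = Y / (Y.norm() + eps)
    for _ in range(steps):
        # Classical cubic Newton-Schulz step: each singular value s maps to 1.5*s - 0.5*s^3
        Y = 1.5 * Y - 0.5 * Y @ (Y.T @ Y)
    return Y.T if transpose else Y

With the plain cubic coefficients, convergence is slow for matrices with very small singular values, which is exactly the issue that tuned coefficient schedules like Polar Express address.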
Substituting \(A_\mathrm{opt}(\Lambda)\) back into the Lagrangian, we uncover the dual problem:

\[\max_{\Lambda \in \mathbb{R}^{n\times n}}\; \mathcal{L}(A_\mathrm{opt}(\Lambda), \Lambda) \;=\; \max_{\Lambda \in \mathbb{R}^{n\times n}}\; -\big\|G + W(\Lambda + \Lambda^\top)\big\|_\mathrm{nuclear}.\]
In contrast to the primal problem (1), the dual problem is completely unconstrained. We may solve the dual problem by running gradient ascent on the Lagrangian dual function \(\mathcal{L}(A_\mathrm{opt}(\Lambda), \Lambda)\)—a technique formally known as dual ascent. After some work, the gradient of the dual function—or, more precisely, a subgradient—is given by the following formula:

\[H(\Lambda) := W^\top A_\mathrm{opt}(\Lambda) + A_\mathrm{opt}(\Lambda)^\top W.\]
To obtain this expression, we have applied the chain rule and the fact that \(\operatorname{msign}(X)\) is in the subdifferential of \(\|X\|_\mathrm{nuclear}\).
This expression for \(H(\Lambda)\) also has an intuitive interpretation: it measures the deviation of the current setting of \(A_\mathrm{opt}(\Lambda)\) from satisfying the tangent space condition. Jianlin’s solution can be interpreted as running a fixed point iteration on the first-order optimality condition for the dual problem: \(H(\Lambda_\mathrm{opt}) = 0\). Instead of running this fixed point iteration, we propose a different approach known as dual ascent.
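If you want to sanity-check the formula for \(H(\Lambda)\), here is a small sketch of my own (not part of the original derivation) that compares it against PyTorch autograd applied to the dual function. It reuses the msign function defined above, and the sizes are arbitrary.

# Check the dual (sub)gradient formula against autograd on the dual function
m, n = 6, 3
W = torch.linalg.qr(torch.randn(m, n, dtype=torch.float64))[0]   # a point on Stiefel(m, n)
G = torch.randn(m, n, dtype=torch.float64)                        # stand-in for the gradient of the cost
Lambda = torch.randn(n, n, dtype=torch.float64, requires_grad=True)

# The dual function is the Lagrangian evaluated at A_opt(Lambda)
dual = -torch.linalg.matrix_norm(G + W @ (Lambda + Lambda.T), ord='nuc')
dual.backward()

# Closed-form subgradient H(Lambda), built from A_opt(Lambda) = -msign(G + W (Lambda + Lambda^T))
with torch.no_grad():
    A_opt = -msign(G + W @ (Lambda + Lambda.T))
    H = W.T @ A_opt + A_opt.T @ W
print(torch.allclose(Lambda.grad, H, atol=1e-6))   # expect True for generic random data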
The dual ascent algorithm¶
In this section, we write down a gradient ascent algorithm to solve the Lagrangian dual problem. Given a tolerance \(\mathtt{tol}>0\) and a step size \(\alpha>0\) for updating the dual variable \(\Lambda\), the algorithm is given by:
1. Initialize the dual variable: \(\Lambda = -\tfrac{1}{4} \times (W^\top G + G^\top W)\).
2. Compute the candidate update direction: \(A = - \operatorname{msign}(G + 2W \Lambda)\).
3. Measure the deviation of \(A\) from the tangent space: \(H = W^\top A + A^\top W\).
4. Check the stopping criterion:
   - If the deviation is small enough, i.e. \(\|H\|_\mathrm{F} / n < \mathtt{tol}\), then return \(A\).
   - Otherwise, update the dual variable: \(\Lambda \gets \Lambda + \alpha \times H\) and go back to step 2.
Observe that the dual variable \(\Lambda\) remains symmetric throughout this procedure, so we can use \(2 \Lambda\) in place of \(\Lambda + \Lambda^\top\) at step 2. The motivation for the special initialization of \(\Lambda\) is that it leads to the algorithm terminating on the first step if \(W\) is square. This is because step 2 already recovers the optimal value of \(A\) for the square case and so \(H=0\) at step 3. In actual neural network training, where \(G\) may not change much between steps because of momentum, it might make more sense to warm start \(\Lambda\) from the previous iteration.
Once this algorithm terminates, we take the returned value of the primal variable \(A\) and make the tangent space update \(W \gets W + \eta \times A\). The final step is to retract the updated weights back to the manifold. We will work out a retraction map in the next section.
Working out the retraction map¶
An update in the tangent space will diverge slightly from the manifold for finite step sizes \(\eta\). As such, we need to find a retraction map to project the updated weights back to the manifold. It turns out that the retraction map can be implemented in a simple way, by introducing an extra matrix \(C\) to the update:

\[W \gets (W + \eta A)\, C.\]
We just need to solve for the proper value of \(C\). Checking the semi-orthogonality condition and using the fact that \(W^\top A + A^\top W = 0\) because the update direction \(A\) belongs to the tangent space, we find that:

\[C^\top (W + \eta A)^\top (W + \eta A)\, C = C^\top \big(I_n + \eta^2 A^\top A\big)\, C = I_n.\]
Even though \(A\) is an output of \(\operatorname{msign}\), it may not hold that \(A^\top A = I_n\) because \(A\) may be low rank. We need to find a matrix \(C\) satisfying \(C^\top[I_n - A^\top A + (1+\eta^2) \cdot A^\top A]C = I_n\). This task is made substantially easier by observing that \(A^\top A\) and \(I_n - A^\top A\) are orthogonal projectors. We can then read off a suitable value for \(C\) as:

\[C = I_n - A^\top A + \frac{1}{\sqrt{1+\eta^2}}\, A^\top A.\]
While it is nice to have an analytical expression for the retraction map, in practice it might be numerically advantageous just to use \(\operatorname{msign}\) to project the updated weights back to the manifold.
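As a quick numerical illustration (my own sketch, reusing torch and the msign function from above; the sizes and step size are arbitrary), the snippet below applies both the analytical retraction and the msign-based projection to a tangent direction and checks that each lands back on the manifold.

# Compare the analytical retraction with the msign-based projection
m, n, eta = 6, 3, 0.1
W = torch.linalg.qr(torch.randn(m, n, dtype=torch.float64))[0]   # a point on Stiefel(m, n)

# Build a tangent direction that is also an msign output: its columns lie in the
# orthogonal complement of range(W), so W^T A = 0 and hence W^T A + A^T W = 0
K = torch.randn(m, n, dtype=torch.float64)
A = msign(K - W @ (W.T @ K))

# Analytical retraction: (W + eta * A) @ C with C = I - A^T A + A^T A / sqrt(1 + eta^2)
AtA = A.T @ A
C = torch.eye(n, dtype=torch.float64) - AtA + AtA / (1 + eta ** 2) ** 0.5
W_retracted = (W + eta * A) @ C

# Alternative mentioned above: simply project back to the manifold with msign
W_projected = msign(W + eta * A)

for W_new in (W_retracted, W_projected):
    print(torch.allclose(W_new.T @ W_new, torch.eye(n, dtype=torch.float64), atol=1e-10))

Note that this construction only produces tangent vectors with \(W^\top A = 0\), which is enough for a quick check but is a special case of the full tangent space.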
PyTorch implementation¶
Here we give a basic PyTorch implementation for solving manifold Muon via dual ascent. The code re-uses the msign
function defined earlier in the post.
import math

def manifold_muon(W, G, eta=0.1, alpha=0.01, steps=100, tol=1e-6):
    # Ensure that W and G are both tall matrices
    should_transpose = W.shape[0] < W.shape[1]
    if should_transpose:
        W = W.T
        G = G.T
    # Initialize the dual variable
    Lambda = -0.25 * (W.T @ G + G.T @ W)
    # Ascend on the dual problem to find the update direction A
    for step in range(steps):
        # Update the candidate direction A (sign convention: A here equals -A_opt(Lambda))
        A = msign(G + 2 * W @ Lambda)
        # Measure deviation of A from the tangent space
        H = W.T @ A + A.T @ W
        # Check the stopping criterion
        if torch.norm(H) / math.sqrt(H.numel()) < tol:
            break
        # Update the dual variable: an ascent step with a linearly decaying step size
        # (the minus sign compensates for A being -A_opt above)
        Lambda -= alpha * (1 - step / steps) * H
    # Descend on the primal problem
    new_W = W - eta * A
    # Retract to the manifold
    new_W += new_W @ A.T @ A * (1 / math.sqrt(1 + eta**2) - 1)
    # Restore the shape of the solution and return
    return new_W.T if should_transpose else new_W
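Finally, here is a hypothetical usage sketch (the sizes, seed, and hyperparameters are my own choices) that applies manifold_muon to a random point on the manifold and reports how far the updated weights deviate from semi-orthogonality.

# Hypothetical usage example, reusing torch, math, msign and manifold_muon from above
torch.manual_seed(0)
m, n = 64, 32
W = msign(torch.randn(m, n, dtype=torch.float64))   # a random point on Stiefel(m, n)
G = torch.randn(m, n, dtype=torch.float64)          # stand-in for the gradient of the cost

new_W = manifold_muon(W, G, eta=0.1, alpha=0.01)

# The updated weights should remain approximately semi-orthogonal
err = torch.linalg.matrix_norm(new_W.T @ new_W - torch.eye(n, dtype=torch.float64))
print(f"deviation from the Stiefel manifold: {err.item():.2e}")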
Acknowledgments¶
I am grateful to Leloy and Jianlin Su for sharing their excellent work on this topic. I also want to acknowledge Cédric Simal, who independently proposed studying the dual problem to me, after I had worked out this dual ascent approach. I am incredibly grateful to the team at Thinking Machines for supporting me to explore this problem. Any mistakes in this writeup are my own responsibility.