Newton-Schulz¶
On this page, we will work out a family of iterative algorithms for “orthogonalizing” a matrix, by which we mean transforming either the rows or the columns of the matrix to form an orthonormal set of vectors. These so-called “Newton-Schulz” iterations are a useful family of algorithms to keep in your toolbox. We proposed using these iterations for neural net optimization in our paper:
Jeremy Bernstein & Laker Newhouse, arXiv 2024
Before that, we included the iteration in an appendix of our workshop paper, and before that I actually worked out the ideas directly on Twitter with my collaborator Tim Large. We used a particular cursed quintic iteration in the Muon optimizer, which was used to set speed records for training NanoGPT:
Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse & Jeremy Bernstein, blog post 2024
Since then, the iteration has been applied in new optimizers such as Scion, improved SOAP, and Mango. At the bottom of this page, we provide further historical connections for these techniques.
Problem statement¶
We wish to approximate the map that sends a matrix \(M\in\mathbb{R}^{m\times n}\) with reduced SVD \(M = U \Sigma V^\top\) to the matrix \(U V^\top\). This map can be thought of as “snapping the singular values of \(M\) to one”, with the exception that the iterations we consider will actually fix zero singular values at zero. But ignoring this detail, the map is given by:
\[
M = U \Sigma V^\top \;\mapsto\; U V^\top.
\]
This operation is sometimes referred to as “symmetric orthogonalization” because no row or column of the matrix \(M\) is treated as special in the procedure. This is in contrast to Gram-Schmidt orthogonalization, which picks an ordering of the row or column vectors and then orthogonalizes each vector against those that came before it.
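For concreteness, here is a minimal sketch of computing this map exactly via an SVD (assuming PyTorch is available; the function name is just for illustration). The iterations developed below approximate the same map without ever forming the SVD.

```python
import torch

def orthogonalize_via_svd(M: torch.Tensor) -> torch.Tensor:
    # Exact symmetric orthogonalization: send M = U Σ Vᵀ to U Vᵀ.
    U, S, Vh = torch.linalg.svd(M, full_matrices=False)  # reduced SVD
    return U @ Vh

M = torch.randn(4, 3)
Q = orthogonalize_via_svd(M)
print(Q.T @ Q)  # ≈ 3x3 identity: the columns of Q are orthonormal
```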
Odd polynomial iterations¶
We will consider iterations based on odd matrix polynomials of the form:
\[
p(X) = a\,X + b\,XX^\top X + c\,(XX^\top)^2 X + d\,(XX^\top)^3 X + \dots
\]
which acts on a matrix \(X \in \mathbb{R}^{m \times n}\). The important property of an odd matrix polynomial of this form is that it commutes with the singular value decomposition, in the sense that:
\[
p(X) = p(U \Sigma V^\top) = U\, p(\Sigma)\, V^\top.
\]
So, to apply an odd polynomial \(p\) to the singular values, it is enough to apply it to the overall matrix \(X\). Since the matrix of singular values \(\Sigma\) is diagonal, this reduces to applying the scalar polynomial
\[
f(x) = a\,x + b\,x^3 + c\,x^5 + d\,x^7 + \dots
\]
to the diagonal entries of \(\Sigma\). In what follows we will simply specify formulae for scalar polynomials \(f\), with the understanding that they will be extended to matrix polynomials \(p\) as specified above. Then our task is just to produce odd scalar polynomials \(f(x)\) that, when iterated like \(f \circ f \circ f \circ \dots \circ f(x)\), converge to the sign function \(\operatorname{sign}(x)\).
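As a quick numerical sanity check, the following sketch (assuming PyTorch; the coefficients are arbitrary illustrative values) verifies that an odd matrix polynomial acts on the singular values via the matching scalar polynomial:

```python
import torch

def p(X):
    # odd matrix polynomial p(X) = 2 X - X Xᵀ X
    return 2 * X - X @ X.T @ X

def f(s):
    # matching scalar polynomial f(x) = 2 x - x³
    return 2 * s - s**3

X = torch.randn(5, 3, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(X, full_matrices=False)

# p(X) should equal U f(Σ) Vᵀ, i.e. f applied entrywise to the singular values.
print(torch.allclose(p(X), U @ torch.diag(f(S)) @ Vh))  # True
```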
A cubic iteration¶
We begin with the simplest Newton-Schulz iteration, based on the cubic polynomial:
\[
f(x) = \tfrac{3}{2}\,x - \tfrac{1}{2}\,x^3.
\]
We plot \(f(x)\) on the left and on the right we plot \(f(x)\) iterated five times to yield \(f(f(f(f(f(x)))))\).
As can be seen, by iterating \(f\) several times, the graph starts to resemble that of the sign function \(\operatorname{sign}(x)\), at least on an interval around the origin. In fact, you can check that if we iterate \(f\) an infinite number of times, we will obtain precisely the sign function on the open interval \((-\sqrt{3},\sqrt{3})\), with zero mapping to zero. As a consequence, if we iterate the corresponding matrix polynomial \(p(X) = \frac{3}{2}X - \frac{1}{2}XX^\top X\), we will approximate the sign function applied element-wise to the singular values of \(X\), thereby orthogonalizing the matrix. The only caveat is that we need to ensure all singular values of the initial matrix lie in this interval. We can achieve this via a simple pre-processing step, mapping \(X \mapsto X / \|X\|_F\), which forces the singular values into \([0,1]\).
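Here is a minimal sketch of the resulting procedure (assuming PyTorch; the function name and step count are illustrative):

```python
import torch

def newton_schulz_cubic(M: torch.Tensor, steps: int = 10) -> torch.Tensor:
    # Pre-process so every singular value lies in [0, 1], safely inside (-√3, √3).
    X = M / M.norm()
    # Iterate the cubic p(X) = 3/2 X - 1/2 X Xᵀ X.
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

M = torch.randn(4, 3)
Q = newton_schulz_cubic(M)
print(Q.T @ Q)  # ≈ identity, up to the quality of the approximation
```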
A quintic iteration¶
Using a higher-order polynomial provides more degrees of freedom in our design space, which we can use to obtain faster convergence. In this section, we consider the quintic iteration given by:
which is actually implemented in the Modula package for dualizing linear layers. Again, we plot one and five iterations of this polynomial:
As can be seen, after 5 iterations the quintic iteration has achieved a substantially closer approximation to the sign function than the cubic iteration, at least on the interval \([-3/2,3/2]\).
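To give a concrete sketch of what a convergent quintic update looks like in code, the snippet below uses the classical fifth-order Newton-Schulz coefficients \((15, -10, 3)/8\), which satisfy \(f(1)=1\) and \(f'(1)=0\). These coefficients are illustrative only and are not necessarily the ones used in Modula:

```python
import torch

def newton_schulz_quintic(M: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Pre-process so the singular values start inside the convergence interval.
    X = M / M.norm()
    # Iterate f(x) = (15 x - 10 x³ + 3 x⁵) / 8, lifted to a matrix polynomial.
    for _ in range(steps):
        A = X @ X.T
        X = (15 * X - 10 * A @ X + 3 * A @ A @ X) / 8
    return X
```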
A cursed quintic iteration¶
We applied a Newton-Schulz iteration in the Muon optimizer used in the NanoGPT speedrun. Keller experimented with tuning the coefficients in the iteration and found that the most important thing for fast convergence of the optimizer was to inflate the small singular values as fast as possible. And to keep the wall-clock time low, he needed to do this in the smallest number of iterations possible. This is achieved by making the first coefficient in the polynomial as large as possible, thereby maximizing the slope of the polynomial at \(x=0\). Keller settled on the following iteration:
\[
f(x) = 3.4445\,x - 4.7750\,x^3 + 2.0315\,x^5.
\]
Plotting the polynomial after one and five iterations, we see some peculiar behaviour:
This iteration oscillates and in fact does not converge! To see why, observe that a convergent iteration must at the very least satisfy \(f(1) = 1\), so that \(x=1\) is a fixed point. In turn, this implies that the sum of the coefficients should equal 1. But for Keller’s polynomial, the coefficients sum to
\[
3.4445 - 4.7750 + 2.0315 = 0.7010 \neq 1.
\]
In short, the cursed quintic iteration sacrifices convergence for speed.
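For reference, here is a simplified sketch of an update using these coefficients. The production Muon implementation includes further details (such as low-precision arithmetic) that are omitted here:

```python
import torch

def cursed_quintic(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Coefficients deliberately chosen for speed over convergence.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)  # normalize so the singular values start in [0, 1]
    for _ in range(steps):
        A = X @ X.T
        X = a * X + b * A @ X + c * A @ A @ X
    return X
```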
Designing your own iteration¶
Designing these polynomial iterations can be a surprisingly fun exercise. If you’d like to explore designing your own iteration, you can start with a polynomial of the form:
\[
f(x) = a\,x + b\,x^3 + c\,x^5 + d\,x^7 + e\,x^9 + \dots
\]
And then choose the coefficients \(a,b,c,d,e,...\) to achieve your desired behaviour. Two important things to consider are:
What order do you want to truncate at? A higher-order iteration can converge in fewer steps, but each step is more expensive. There is a trade-off here.
Do you want the iterations to converge? If so, you at least need to enforce that the coefficients sum to 1 so that \(f(1) = 1\). You could consider enforcing additional derivative conditions, such as that \(\partial f / \partial x = 0\) at \(x=1\), to further stabilize the convergence.
After making these decisions, you may have leftover degrees of freedom. A fun way to fix these degrees of freedom is to open up Desmos and play around with the coefficients using sliders.
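If you prefer scripting to Desmos, a small sketch along the following lines (assuming numpy and matplotlib are available; the coefficients are placeholders to tune) lets you check the conditions above and eyeball the iterated polynomial:

```python
import numpy as np
import matplotlib.pyplot as plt

# Candidate odd quintic f(x) = a x + b x³ + c x⁵ with placeholder coefficients.
a, b, c = 2.0, -1.5, 0.5

def f(x):
    return a * x + b * x**3 + c * x**5

print("f(1)  =", a + b + c)           # should equal 1 for convergence
print("f'(1) =", a + 3 * b + 5 * c)   # optionally 0 for extra stability at x = 1

x = np.linspace(-1.5, 1.5, 500)
y = x
for _ in range(5):
    y = f(y)  # iterate the polynomial five times

plt.plot(x, f(x), label="one iteration")
plt.plot(x, y, label="five iterations")
plt.legend()
plt.show()
```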
Historical connections¶
The procedure of symmetric orthogonalization appears in a number of different contexts:
it is used in solving the orthogonal Procrustes problem.
it computes the “orthogonal polar factor” in the polar decomposition of a matrix.
it was used by Per-Olov Löwdin in the 1950s to perform atomic and molecular orbital calculations.
it is used for doing Frank-Wolfe optimization over the spectral norm ball.
it was proposed for deep learning optimization in the paper “preconditioned spectral descent for deep learning”—albeit computed via matrix sketching rather than Newton-Schulz iterations.
A Newton-Schulz iteration was used to orthogonalize the weight matrices (but not the updates!) in deep learning in the paper “sorting out Lipschitz function approximation”.
The earliest references on the Newton-Schulz iteration itself seem to be “some iterative methods for improving orthonormality” (Kovarik, 1970) and “an iterative algorithm for computing the best estimate of an orthogonal matrix” (Björck & Bowie, 1971). To justify using the name “Newton-Schulz” for these iterations, we note that Higham used it in these slides. The idea of graphically tuning the coefficients of the iteration to obtain certain performance characteristics is, to the best of my knowledge, our own original idea.