Hello, World!
On this page, we will build a simple training loop to fit an MLP to some randomly generated data. We start by sampling the data. Modula uses JAX to handle array computations, so we sample with JAX as well. JAX requires us to explicitly pass in the state of the random number generator.
[1]:
import jax
import jax.numpy as jnp
input_dim = 784
output_dim = 10
batch_size = 128
key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (input_dim, batch_size))
targets = jax.random.normal(key, (output_dim, batch_size))
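Because JAX random functions are pure, the key fully determines the samples. The short sketch below is only an illustration and not part of the original notebook: it shows that reusing a key reproduces the same draw, and that `jax.random.split` is the usual way to obtain separate keys when independent samples are wanted. The variable names are chosen for the example.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key with the same shape reproduces exactly the same samples.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
same_key_matches = bool(jnp.array_equal(a, b))

# Splitting produces two new keys, which give different sample streams.
key_x, key_y = jax.random.split(key)
x = jax.random.normal(key_x, (3,))
y = jax.random.normal(key_y, (3,))
split_keys_differ = not bool(jnp.array_equal(x, y))

print(same_key_matches, split_keys_differ)
```

For fitting random data as on this page, the distinction matters little, but splitting keys is the safer habit when the inputs and targets are meant to be statistically independent.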
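Next, we will build our neural network. We import the basic Linear and ReLU modules and compose them using the @ operator. The shapes in the code below imply that @ behaves like function composition: the last module appended, Linear(width, input_dim), is the first one applied to the input. Calling mlp.jit() tries to make all the internal module methods more efficient using just-in-time compilation from JAX.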
[2]:
from modula.atom import Linear
from modula.bond import ReLU
width = 256
mlp = Linear(output_dim, width)
mlp @= ReLU()
mlp @= Linear(width, width)
mlp @= ReLU()
mlp @= Linear(width, input_dim)
print(mlp)
mlp.jit()
CompositeModule
...consists of 3 atoms and 2 bonds
...non-smooth
...input sensitivity is 1
...contributes proportion 3 to feature learning of any supermodule
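To make the data flow concrete, here is a minimal plain-JAX sketch of a network with the same layer shapes. It is an assumption-laden illustration, not Modula's implementation: the helper names `init_weights` and `forward` are hypothetical, the initialization is an arbitrary choice, and Modula's own modules may apply scaling factors that are omitted here. It also shows `jax.jit`, the JAX facility for just-in-time compilation.

```python
import jax
import jax.numpy as jnp

input_dim, width, output_dim, batch_size = 784, 256, 10, 128

def init_weights(key):
    # One (fanout, fanin) matrix per linear layer, in the order they are applied.
    shapes = [(width, input_dim), (width, width), (output_dim, width)]
    keys = jax.random.split(key, len(shapes))
    # Scaled Gaussian initialization: an arbitrary choice for this sketch.
    return [jax.random.normal(k, s) / jnp.sqrt(s[1]) for k, s in zip(keys, shapes)]

def forward(weights, x):
    # x has shape (input_dim, batch_size); each layer left-multiplies by its matrix.
    w1, w2, w3 = weights
    h = jax.nn.relu(w1 @ x)
    h = jax.nn.relu(w2 @ h)
    return w3 @ h

# Compile the forward pass; the compiled version is reused for same-shaped inputs.
fast_forward = jax.jit(forward)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weights = init_weights(key_w)
x = jax.random.normal(key_x, (input_dim, batch_size))
out = fast_forward(weights, x)
print(out.shape)
```

The output has one column per example and one row per output dimension, matching the (output_dim, batch_size) layout of the targets sampled earlier.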
Next, we choose our error measure. An error measure lets us compute both the loss of the model and the derivative of the loss with respect to the model outputs. For simplicity, we will just use squared error.
[3]:
from modula.error import SquareError
error = SquareError()
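As a rough illustration of what an error measure provides, the sketch below writes a squared error and its derivative by hand in JAX and checks the derivative against `jax.grad`. The function names are hypothetical, and the normalization (a mean over all entries) is an assumption; Modula's SquareError may normalize differently.

```python
import jax
import jax.numpy as jnp

def square_error(outputs, targets):
    # Mean of squared differences over every entry (normalization is an assumption).
    return jnp.mean((outputs - targets) ** 2)

def square_error_grad(outputs, targets):
    # Analytic derivative of the mean squared error with respect to the outputs.
    return 2.0 * (outputs - targets) / outputs.size

key_o, key_t = jax.random.split(jax.random.PRNGKey(0))
outputs = jax.random.normal(key_o, (10, 128))
targets = jax.random.normal(key_t, (10, 128))

loss = square_error(outputs, targets)
auto_grad = jax.grad(square_error)(outputs, targets)   # differentiates w.r.t. outputs
manual_grad = square_error_grad(outputs, targets)
grads_agree = bool(jnp.allclose(auto_grad, manual_grad, atol=1e-6))
print(float(loss), grads_agree)
```

The gradient has the same shape as the outputs, which is what a backward pass needs as its starting point.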
Finally, we are ready to train our model. The method mlp.backward takes as input the weights, the activations, and the gradient of the error. It returns the gradient of the loss with respect to both the model weights and the inputs. The method mlp.dualize takes in the gradient with respect to the weights and solves for the vector of unit modular norm that maximizes the linearized improvement in loss.
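To give some intuition for dualization, consider a single weight matrix and assume, purely for illustration, that its norm is the spectral norm. In that case the unit-norm matrix with the largest inner product against a gradient G = U diag(s) Vᵀ is U Vᵀ. The sketch below checks this with an SVD. The function `dualize_spectral` is hypothetical; how mlp.dualize actually computes its result, and which scaling it applies, is not shown on this page.

```python
import jax
import jax.numpy as jnp

def dualize_spectral(grad):
    # For grad = U diag(s) V^T, the matrix U V^T has all singular values equal to 1
    # and maximizes <grad, D> over matrices D with spectral norm at most 1.
    u, _, vt = jnp.linalg.svd(grad, full_matrices=False)
    return u @ vt

grad = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
direction = dualize_spectral(grad)

# Every singular value of the dualized matrix should be (numerically) 1.
singular_values = jnp.linalg.svd(direction, compute_uv=False)

# The inner product <grad, direction> should equal the sum of grad's singular values.
linearized_gain = jnp.sum(grad * direction)
nuclear_norm = jnp.sum(jnp.linalg.svd(grad, compute_uv=False))
print(singular_values, float(linearized_gain), float(nuclear_norm))
```

Stepping the weights against this direction then gives the largest first-order decrease in loss among all steps of the same spectral norm.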
[4]:
steps = 1000
learning_rate = 0.1
key = jax.random.PRNGKey(0)
w = mlp.initialize(key)
for step in range(steps):
    # compute outputs and activations
    outputs, activations = mlp(inputs, w)
    # compute loss
    loss = error(outputs, targets)
    # compute error gradient
    error_grad = error.grad(outputs, targets)
    # compute gradient of weights
    grad_w, _ = mlp.backward(w, activations, error_grad)
    # dualize gradient
    d_w = mlp.dualize(grad_w)
    # compute scheduled learning rate
    lr = learning_rate * (1 - step / steps)
    # update weights
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]
    if step % 100 == 0:
        print(f"Step {step:3d} \t Loss {loss:.6f}")
Step 0 Loss 0.976274
Step 100 Loss 0.001989
Step 200 Loss 0.001537
Step 300 Loss 0.001194
Step 400 Loss 0.000885
Step 500 Loss 0.000627
Step 600 Loss 0.000420
Step 700 Loss 0.000255
Step 800 Loss 0.000134
Step 900 Loss 0.000053
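The learning rate in the loop above decays linearly with the step count. The small sketch below isolates that schedule so its values are easy to inspect; the helper name `scheduled_lr` is introduced only for this example.

```python
steps = 1000
learning_rate = 0.1

def scheduled_lr(step):
    # Linear decay from learning_rate at step 0 towards zero at step == steps.
    return learning_rate * (1 - step / steps)

# Inspect the schedule at the start, a quarter in, halfway, and the final step.
for step in (0, 250, 500, 999):
    print(step, scheduled_lr(step))
```

Since the loop runs step from 0 to steps - 1, the learning rate shrinks towards zero but never reaches it, so every iteration still updates the weights.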