Hello, World!

On this page, we build a simple training loop that fits an MLP to randomly generated data. We start by sampling the data. Modula uses JAX for array computations, so we sample with JAX as well. JAX requires us to pass the state of the random number generator explicitly, in the form of a key.

[1]:
import jax
import jax.numpy as jnp

input_dim = 784
output_dim = 10
batch_size = 128

key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (input_dim, batch_size))
targets = jax.random.normal(key, (output_dim, batch_size))
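
The cell above passes the same key to both sampling calls, which keeps the example short but means the two arrays are not guaranteed to be independent. As a minimal sketch of the more usual JAX pattern (the names `input_key`, `target_key`, and `again` are illustrative and not part of Modula), a key can be split into subkeys with `jax.random.split`:

```python
import jax

key = jax.random.PRNGKey(0)

# Split the parent key into two subkeys, one per array we want to sample.
input_key, target_key = jax.random.split(key)

inputs = jax.random.normal(input_key, (784, 128))
targets = jax.random.normal(target_key, (10, 128))

# Sampling is deterministic: the same key and shape give the same values.
again = jax.random.normal(input_key, (784, 128))
print(bool((inputs == again).all()))  # True
```

Because a key fully determines the samples, rerunning the notebook reproduces the same data.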

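Next, we build our neural network. We import the basic Linear and ReLU modules and compose them using the @ operator. As the code below shows, Linear(output_dim, width) comes first and Linear(width, input_dim) comes last, so composition reads like function composition: the rightmost module is applied to the input first. Calling mlp.jit() tries to make all the internal module methods more efficient using just-in-time compilation from JAX.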

[2]:
from modula.atom import Linear
from modula.bond import ReLU

width = 256

mlp = Linear(output_dim, width)
mlp @= ReLU()
mlp @= Linear(width, width)
mlp @= ReLU()
mlp @= Linear(width, input_dim)

print(mlp)

mlp.jit()

CompositeModule
...consists of 3 atoms and 2 bonds
...non-smooth
...input sensitivity is 1
...contributes proportion 3 to feature learning of any supermodule
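
To make the composition order concrete, here is a rough plain-JAX sketch of the same three-layer shape. It is not Modula's implementation: the function `forward`, the weight names, and the random initialization are assumptions for illustration, and any scaling or normalization that Modula's modules apply is left out. The sketch also shows `jax.jit`, the JAX just-in-time compiler.

```python
import jax
import jax.numpy as jnp

input_dim, width, output_dim, batch_size = 784, 256, 10, 128

def forward(weights, x):
    # Hypothetical stand-in for the composed module: the layer listed
    # last in the composition (input_dim -> width) is applied first.
    w_in, w_hidden, w_out = weights
    h = jax.nn.relu(w_in @ x)        # shape (width, batch)
    h = jax.nn.relu(w_hidden @ h)    # shape (width, batch)
    return w_out @ h                 # shape (output_dim, batch)

keys = jax.random.split(jax.random.PRNGKey(0), 4)
weights = [
    jax.random.normal(keys[0], (width, input_dim)) / jnp.sqrt(input_dim),
    jax.random.normal(keys[1], (width, width)) / jnp.sqrt(width),
    jax.random.normal(keys[2], (output_dim, width)) / jnp.sqrt(width),
]
x = jax.random.normal(keys[3], (input_dim, batch_size))

# jax.jit traces the function once and compiles it for later calls.
fast_forward = jax.jit(forward)
outputs = fast_forward(weights, x)
print(outputs.shape)  # (10, 128)
```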

Next, we choose our error measure. An error measure lets us compute both the loss of the model and the derivative of the loss with respect to the model outputs. For simplicity, we use squared error.

[3]:
from modula.error import SquareError

error = SquareError()
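
As a hedged illustration of what an error measure provides, the sketch below writes a squared error and its derivative in plain JAX. The functions `square_error` and `square_error_grad` are hypothetical, and the averaging over all entries is an assumption that may not match the normalization used by SquareError.

```python
import jax
import jax.numpy as jnp

def square_error(outputs, targets):
    # Mean of the squared differences over every entry (assumed normalization).
    return jnp.mean((outputs - targets) ** 2)

def square_error_grad(outputs, targets):
    # Analytic derivative of the mean squared error w.r.t. the outputs.
    return 2 * (outputs - targets) / outputs.size

outputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.zeros((2, 2))

loss = square_error(outputs, targets)  # (1 + 4 + 9 + 16) / 4 = 7.5

# Cross-check the analytic derivative against automatic differentiation.
auto_grad = jax.grad(square_error)(outputs, targets)
print(bool(jnp.allclose(auto_grad, square_error_grad(outputs, targets))))  # True
```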

Finally, we are ready to train our model. The method mlp.backward takes as input the weights, the activations, and the gradient of the error. It returns the gradient of the loss with respect to both the model weights and the inputs. The method mlp.dualize takes the gradient of the loss with respect to the weights and solves for the vector of unit modular norm that maximizes the linearized improvement in loss.

[4]:
steps = 1000
learning_rate = 0.1

key = jax.random.PRNGKey(0)
w = mlp.initialize(key)

for step in range(steps):
    # compute outputs and activations
    outputs, activations = mlp(inputs, w)

    # compute loss
    loss = error(outputs, targets)

    # compute error gradient
    error_grad = error.grad(outputs, targets)

    # compute gradient of weights
    grad_w, _ = mlp.backward(w, activations, error_grad)

    # dualize gradient
    d_w = mlp.dualize(grad_w)

    # compute scheduled learning rate
    lr = learning_rate * (1 - step / steps)

    # update weights
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]

    if step % 100 == 0:
        print(f"Step {step:3d} \t Loss {loss:.6f}")

Step   0         Loss 0.976274
Step 100         Loss 0.001989
Step 200         Loss 0.001537
Step 300         Loss 0.001194
Step 400         Loss 0.000885
Step 500         Loss 0.000627
Step 600         Loss 0.000420
Step 700         Loss 0.000255
Step 800         Loss 0.000134
Step 900         Loss 0.000053
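
The dualization step in the loop above can be illustrated on a single weight matrix. The sketch below assumes the norm is the plain spectral norm, in which case the answer is obtained by setting every singular value of the gradient to one. The helper `dualize_spectral` is hypothetical, and Modula's modular norm, per-module scaling, and actual algorithm are not reproduced here.

```python
import jax
import jax.numpy as jnp

def dualize_spectral(grad):
    # Among matrices with spectral norm at most 1, U @ V^T has the
    # largest inner product with grad, where grad = U @ diag(s) @ V^T.
    u, _, vt = jnp.linalg.svd(grad, full_matrices=False)
    return u @ vt

key = jax.random.PRNGKey(0)
grad = jax.random.normal(key, (4, 3))
d = dualize_spectral(grad)

# The dualized matrix has unit spectral norm (up to floating-point error).
spectral_norm = jnp.linalg.norm(d, ord=2)

# Its inner product with grad equals the sum of grad's singular values.
inner = jnp.sum(grad * d)
nuclear = jnp.sum(jnp.linalg.svd(grad, compute_uv=False))
print(spectral_norm, inner, nuclear)
```

Subtracting a multiple of `d` from the weights, as the training loop does with `d_w`, then gives the largest linearized decrease in loss for a step of a given norm.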