Hello, World!

On this page, we build a simple training loop that fits an MLP to some randomly generated data. We start by sampling the data. Modula uses JAX to handle array computations, so we also use JAX for sampling. Note that JAX requires us to explicitly pass in the state of the random number generator.

[8]:
import jax
import jax.numpy as jnp

input_dim = 784
output_dim = 10
batch_size = 128

key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (batch_size, input_dim))
targets = jax.random.normal(key, (batch_size, output_dim))
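
As an aside, reusing the same key for two random draws is generally discouraged in JAX; the idiomatic pattern is to split the key so that each draw gets independent randomness. Here is a minimal sketch of that variant (we leave the cell above unchanged so that the outputs shown later on this page stay reproducible):

input_key, target_key = jax.random.split(key)
inputs = jax.random.normal(input_key, (batch_size, input_dim))
targets = jax.random.normal(target_key, (batch_size, output_dim))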

Next, we build our neural network. We import the basic Linear and ReLU modules and compose them using the @ operator. Calling mlp.jit() tries to make all the internal module methods more efficient using JAX just-in-time compilation.

[9]:
from modula.atom import Linear
from modula.bond import ReLU

width = 256

mlp = Linear(output_dim, width)
mlp @= ReLU()
mlp @= Linear(width, width)
mlp @= ReLU()
mlp @= Linear(width, input_dim)

print(mlp)

mlp.jit()
CompositeModule
...consists of 3 atoms and 2 bonds
...non-smooth
...input sensitivity is 1
...contributes proportion 3 to feature learning of any supermodule
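
As a quick sanity check, we can initialize the weights and run a forward pass before training. This is just a sketch using the initialize and call pattern that appears later on this page:

key = jax.random.PRNGKey(0)
w = mlp.initialize(key)
outputs = mlp(inputs, w)
print(outputs.shape)  # should be (batch_size, output_dim), i.e. (128, 10)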

Next, we set up a loss function and create a jitted function that evaluates both the loss and its gradient.

[10]:
def mse(w, inputs, targets):
    # mean squared error of the MLP's predictions against the targets
    outputs = mlp(inputs, w)
    loss = ((outputs - targets) ** 2).mean()
    return loss

mse_and_grad = jax.jit(jax.value_and_grad(mse))
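
As a quick check, we can call this function once on freshly initialized weights. The gradient comes back as a list of arrays matching the structure of the weight list w. A minimal sketch, reusing only calls already shown on this page:

key = jax.random.PRNGKey(0)
w = mlp.initialize(key)
loss, grad_w = mse_and_grad(w, inputs, targets)
print(loss)
print([g.shape for g in grad_w])  # one gradient array per weight array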

Finally, we are ready to train our model. At each step, we apply the method mlp.dualize to the gradient of the loss, which solves for the vector of unit modular norm that maximizes the linearized improvement in loss.
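
In symbols, writing $g$ for the gradient and $\|\cdot\|_{\mathsf{M}}$ for the modular norm, the dualized update is (a sketch of the idea rather than Modula's exact internals):

$$d = \operatorname*{arg\,max}_{\|d\|_{\mathsf{M}} = 1} \langle g, d \rangle, \qquad w \leftarrow w - \eta\, d,$$

where $\eta$ is the (scheduled) learning rate used below.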

[12]:
steps = 1000
learning_rate = 0.1

key = jax.random.PRNGKey(0)
w = mlp.initialize(key)

for step in range(steps):

    # compute loss and gradient of weights
    loss, grad_w = mse_and_grad(w, inputs, targets)

    # dualize gradient
    d_w = mlp.dualize(grad_w)

    # compute scheduled learning rate
    lr = learning_rate * (1 - step / steps)

    # update weights
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]

    if step % 100 == 0:
        print(f"Step {step:3d} \t Loss {loss:.6f}")

Step   0         Loss 0.979311
Step 100         Loss 0.001822
Step 200         Loss 0.001423
Step 300         Loss 0.001066
Step 400         Loss 0.000766
Step 500         Loss 0.000519
Step 600         Loss 0.000340
Step 700         Loss 0.000196
Step 800         Loss 0.000090
Step 900         Loss 0.000025
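
After training, the loop leaves the final weights in w, so we can, for example, re-evaluate the loss on the training batch (a minimal sketch using the functions defined above):

final_loss, _ = mse_and_grad(w, inputs, targets)
print(f"Final loss {final_loss:.6f}")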