Hello, World!¶
On this page, we build a simple training loop that fits an MLP to randomly generated data. We start by sampling the data. Modula uses JAX to handle array computations, so we sample with JAX, which requires us to explicitly pass in the state of the random number generator.
[8]:
import jax
import jax.numpy as jnp
input_dim = 784
output_dim = 10
batch_size = 128
key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (batch_size, input_dim))
targets = jax.random.normal(key, (batch_size, output_dim))
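As an aside, the snippet above reuses the same key for both calls. In JAX the usual pattern for drawing independent samples is to split the key into fresh subkeys first. A minimal sketch of that variant (not used here, so the recorded losses below would differ):
input_key, target_key = jax.random.split(key)  # derive two independent subkeys
inputs = jax.random.normal(input_key, (batch_size, input_dim))
targets = jax.random.normal(target_key, (batch_size, output_dim))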
Next, we build our neural network. We import the basic Linear and ReLU modules and compose them using the @ operator. Calling mlp.jit() tries to make all the internal module methods more efficient using JAX's just-in-time compilation.
[9]:
from modula.atom import Linear
from modula.bond import ReLU
width = 256
mlp = Linear(output_dim, width)
mlp @= ReLU()
mlp @= Linear(width, width)
mlp @= ReLU()
mlp @= Linear(width, input_dim)
print(mlp)
mlp.jit()
CompositeModule
...consists of 3 atoms and 2 bonds
...non-smooth
...input sensitivity is 1
...contributes proportion 3 to feature learning of any supermodule
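Equivalently, the same network could be written as a single composition with the @ operator (a sketch equivalent to the chain of @= statements above); composition reads like function composition, so the rightmost module touches the input first:
mlp = Linear(output_dim, width) @ ReLU() @ Linear(width, width) @ ReLU() @ Linear(width, input_dim)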
Next, we set up a loss function and create a jitted function that evaluates the loss and returns its gradient.
[10]:
def mse(w, inputs, targets):
    outputs = mlp(inputs, w)
    loss = ((outputs - targets) ** 2).mean()
    return loss
mse_and_grad = jax.jit(jax.value_and_grad(mse))
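For intuition about the API, here is a standalone JAX example, independent of Modula: jax.value_and_grad turns a scalar-valued function into one that returns both the value and the gradient with respect to its first argument.
f = lambda x: (x ** 2).sum()
value, grad = jax.value_and_grad(f)(jnp.array([1.0, 2.0]))
# value is 5.0 and grad is [2.0, 4.0]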
Finally, we are ready to train our model. We apply the method mlp.dualize to the gradient of the loss to solve for the vector of unit modular norm that maximizes the linearized improvement in the loss.
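Roughly speaking, and in notation not used elsewhere on this page: if g denotes the gradient list, dualize returns d = argmax over ‖d‖ ≤ 1 of ⟨g, d⟩, where ‖·‖ is the modular norm of the architecture, and the loop below then updates the weights as w ← w − lr · d with lr decayed linearly to zero.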
[12]:
steps = 1000
learning_rate = 0.1
key = jax.random.PRNGKey(0)
w = mlp.initialize(key)
for step in range(steps):
    # compute loss and gradient of weights
    loss, grad_w = mse_and_grad(w, inputs, targets)
    # dualize gradient
    d_w = mlp.dualize(grad_w)
    # compute scheduled learning rate
    lr = learning_rate * (1 - step / steps)
    # update weights
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]
    if step % 100 == 0:
        print(f"Step {step:3d} \t Loss {loss:.6f}")
Step 0 Loss 0.979311
Step 100 Loss 0.001822
Step 200 Loss 0.001423
Step 300 Loss 0.001066
Step 400 Loss 0.000766
Step 500 Loss 0.000519
Step 600 Loss 0.000340
Step 700 Loss 0.000196
Step 800 Loss 0.000090
Step 900 Loss 0.000025
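The trained weights can be used for prediction in the same way as inside the loss function. A minimal check of the final fit, reusing the mlp(inputs, w) call from mse above:
outputs = mlp(inputs, w)  # forward pass with the trained weights
final_loss = ((outputs - targets) ** 2).mean()
print(f"Final loss {final_loss:.6f}")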