Quickstart

Modula is a neural network library built on top of JAX.

Installation

Modula can be installed using pip:

pip install git+https://github.com/modula-systems/modula.git

Or you can clone the repository and install it locally:

git clone https://github.com/modula-systems/modula.git
cd modula
pip install -e .
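
To check that the installation worked, you can try importing the package from the command line:

python -c "from modula.atom import Linear"

If this runs without error, Modula is ready to use.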

Functionality

Modula provides a set of architecture-specific helper functions that are automatically constructed along with the network architecture itself. As an example, let’s build a multi-layer perceptron:

from modula.atom import Linear
from modula.bond import ReLU

mlp = Linear(10, 256)    # Linear(output_dim, input_dim)
mlp @= ReLU()
mlp @= Linear(256, 256)
mlp @= ReLU()
mlp @= Linear(256, 784)  # composition reads right to left: 784-dim inputs, 10-dim outputs

mlp.jit() # makes everything run faster

Behind the scenes, Modula builds a function to randomly initialize the weights of the network:

import jax

key = jax.random.PRNGKey(0)
weights = mlp.initialize(key)
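
The next step needs the gradient of a loss with respect to these weights. Here is a minimal sketch of how to compute it with jax.grad; the toy data, the mean-squared-error loss, and the assumption that the forward pass is invoked as mlp(inputs, weights) are illustrative choices rather than part of Modula's API:

import jax.numpy as jnp

x = jnp.ones(784)  # toy 784-dimensional input
y = jnp.ones(10)   # toy 10-dimensional target

def loss(weights):
    return jnp.mean((mlp(x, weights) - y) ** 2)  # assumed call signature: mlp(inputs, weights)

grad = jax.grad(loss)(weights)  # a list of gradients, one per weight tensor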

Having used JAX to compute the gradient of our loss and stored it as grad, we can now use Modula to dualize the gradient, thereby accelerating gradient descent training:

dualized_grad = mlp.dualize(grad)
weights = [w - 0.1 * dg for w, dg in zip(weights, dualized_grad)]

And after the weight update, we can project the weights back to their natural constraint set:

weights = mlp.project(weights)
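
Putting these steps together gives a simple training loop. This sketch continues the toy example above, with the same illustrative data, loss, and assumed call signature:

learning_rate = 0.1

for step in range(100):
    grad = jax.grad(loss)(weights)     # gradient of the loss with respect to the weights
    dualized_grad = mlp.dualize(grad)  # dualize the gradient
    weights = [w - learning_rate * dg for w, dg in zip(weights, dualized_grad)]
    weights = mlp.project(weights)     # project back onto the constraint set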

In short, Modula lets us treat the weight space of our neural network much like a classical optimization space, complete with duality and projection operations.