Hello, GPT!

In this notebook, we’re going to build a transformer. In particular, we’ll see how to define attention and residual blocks in Modula.

Getting the data

First, let’s download the Shakespeare dataset. The task will be to predict the next character.

[1]:
context = 64
batch_size = 12

from data.shakespeare import load_shakespeare

data = load_shakespeare(context, batch_size)

train_loader = data["train_loader"]
val_loader = data["val_loader"]
encode = data["encode"]
decode = data["decode"]
Downloading Shakespeare dataset...
Processing Shakespeare dataset...
Length of dataset in characters: 1,115,394
Vocabulary size: 65
Train has 1,003,854 tokens
Val has 111,540 tokens
Shakespeare dataset processing complete.

Let’s peek at an example to verify the data loaded correctly!

[2]:
for inputs, targets in train_loader:
    print("Input shape:", inputs.shape)
    print("Target shape:", targets.shape)
    print("First input sequence:", inputs[0][:10], "...")
    print("First target sequence:", targets[0][:10], "...")
    print("\nDecoded input:", decode(inputs[0]))
    print("\nDecoded target:", decode(targets[0]))
    break
Input shape: (12, 64)
Target shape: (12, 64)
First input sequence: [41 53 50 42  1 40 50 53 53 42] ...
First target sequence: [53 50 42  1 40 50 53 53 42  1] ...

Decoded input: cold blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a

Decoded target: old blood no spark of honour bides.

NORTHUMBERLAND:
Be thou a p
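
As a quick round-trip check (a small sketch, not part of the original notebook), we can encode an arbitrary string and decode it back, mirroring how the tokenizer is used later for sampling; the prompt "O Romeo" is just an illustrative choice.

import jax.numpy as jnp

tokens = jnp.array(encode("O Romeo"))
print(tokens)          # integer token ids
print(decode(tokens))  # should print "O Romeo" again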

Defining the architecture

Let’s use a very small setting for our transformer so it is fast to train.

[3]:
# transformer hyperparameters

vocab_size = 65
num_heads = 4
d_embed = 128
d_query = 32
d_value = 32
num_blocks = 4
attention_scale = 1
final_scale = 1

# training hyperparameters

lr = 0.1
beta = 0.95
steps = 2001
log_interval = 10
val_interval = 100
val_iters = 20

Next up, we’ll define the attention module and residual blocks.

Attention in Modula

In Modula, we’ll define attention by stringing together several bond modules to do the parameterless computations. The roadmap is:

  • Map (batch, token, d_embed) into (batch, head, token, d_query) via Linear and SplitIntoHeads (and likewise for the key, and into (batch, head, token, d_value) for the value)

  • Use Rotary Positional Embeddings (RoPE) on the query and the key via Rope

  • Map query and key into attention similarities of shape (batch, head, token, token) via AttentionQK

  • Use a causal mask and then softmax to create attention scores via CausalMask and Softmax

  • Use the attention scores to create output vectors via ApplyAttentionScores, then MergeHeads and Linear

The main difference from a standard transformer is that AttentionQK uses \(1/d_\text{head}\) scaling instead of the standard \(1/\sqrt{d_\text{head}}\). The reason for this is to provide a Lipschitz guarantee for attention that is independent of \(d_\text{head}\). For more information, see Appendix B.6 of Scalable Optimization in the Modular Norm.
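
To make the scaling concrete, here is a minimal sketch in plain JAX of the similarity computation with \(1/d_\text{head}\) scaling. This is not the library's AttentionQK, and the helper name and the (batch, head, token, d_query) shapes are assumptions for illustration only.

import jax.numpy as jnp

def attention_similarities(q, k):
    # q, k: (batch, head, token, d_query) -- assumed shapes for this sketch
    d_query = q.shape[-1]
    # divide by d_query rather than the usual sqrt(d_query)
    return jnp.einsum("bhid,bhjd->bhij", q, k) / d_query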

And here’s the implementation:

[4]:
from modula.atom import Linear
from modula.bond import SplitIntoHeads, MergeHeads, Rope, AttentionQK, CausalMask, Softmax, ApplyAttentionScores, GeLU

def Attention(num_heads, d_embed, d_query, d_value, attention_scale):
    """Multi-head attention"""

    # For keys, queries, and values we add a heads dimension. For the out projection, we remove heads.
    # Remember modules compose right-to-left, and the order is Linear(d_out, d_in)! And @ means compose.
    Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed)
    V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed)
    W = Linear(d_embed, num_heads * d_value) @ MergeHeads()

    # Read right-to-left: rotate (Q, K) with RoPE, apply Q @ K.T, mask, softmax (with a scale we can choose).
    AttentionScores = Softmax(attention_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K)

    # Read right-to-left: apply attention scores, multiply by 1/3 to fix the sensitivity to 1, project back to d_embed.
    return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores)

Let’s check that the sensitivity is 1 at initialization.

[5]:
print(Attention(num_heads, d_embed, d_query, d_value, attention_scale))
CompositeModule
...consists of 4 atoms and 10 bonds
...smooth
...input sensitivity is 1.0
...contributes proportion 4 to feature learning of any supermodule

Residual blocks in Modula

To implement the rest of our transformer, the roadmap is:

  • Embed the input tokens

  • Apply residual blocks for attention and the MLP

  • Project out

All that’s left is to set up the residual blocks. In Modula, we define residual connections using a convex combination. If \(L\) is the number of residual blocks, then we use a convex combination of the identity and the block to get \(x \mapsto \frac{L-1}{L} \cdot x + \frac{1}{L} \cdot \textsf{block}(x)\). The purpose is to create a Lipschitz guarantee that is independent of the number of blocks. For more information, see Proposition 4 of Scalable Optimization in the Modular Norm.

In short, these changes enable Lipschitz guarantees on our transformer even as we scale the width and the depth!
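
As a toy numerical illustration (a sketch, not part of the original notebook): if the block is 1-Lipschitz, the convex combination stays 1-Lipschitz no matter how large \(L\) is. Here jnp.tanh stands in for an arbitrary 1-Lipschitz block.

import jax.numpy as jnp

def residual(x, block, L):
    # convex combination of the identity and the block
    return (1 - 1/L) * x + (1/L) * block(x)

x, y = jnp.array(0.3), jnp.array(-1.2)
for L in [2, 8, 32]:
    ratio = jnp.abs(residual(x, jnp.tanh, L) - residual(y, jnp.tanh, L)) / jnp.abs(x - y)
    print(L, float(ratio))  # the ratio stays at most 1 for every L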

[6]:
from modula.abstract import Identity
from modula.atom import Embed

def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0):
    # Set embed to have mass 1. This controls the proportion of feature learning that it contributes to the whole network.
    embed = Embed(d_embed, vocab_size)
    embed.tare()

    # Let's create attention and MLP layers.
    att = Attention(num_heads, d_embed, d_query, d_value, attention_scale)
    mlp = Linear(d_embed, 4*d_embed) @ GeLU() @ Linear(4*d_embed, d_embed)

    # For our residual connections, L = 2*num_blocks because each block has two residual connections.
    att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att
    mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp

    # We can use powers of a module to compose it with itself many times!
    blocks = (mlp_block @ att_block) ** num_blocks

    # Set all transformer blocks to have mass 5 (by default).
    # So 5/7 of the change in the network output is due to the blocks,
    # and 2/7 of the change in output is due to the embedding and out projection.
    blocks.tare(absolute=blocks_mass)

    out = final_scale * Linear(vocab_size, d_embed)

    return out @ blocks @ embed

And finally we are ready to construct our GPT!

[7]:
model = GPT(
    vocab_size=vocab_size,
    num_heads=num_heads,
    d_embed=d_embed,
    d_query=d_query,
    d_value=d_value,
    num_blocks=num_blocks,
    attention_scale=attention_scale,
    final_scale=final_scale,
)

model.jit()

print(model)
CompositeModule
...consists of 26 atoms and 78 bonds
...non-smooth
...input sensitivity is 1.0
...contributes proportion 7.0 to feature learning of any supermodule
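
Before defining the loss, we can run a quick smoke test (a sketch, not in the original notebook) to confirm the logit shape; the batch of zero tokens is just a placeholder input.

import jax
import jax.numpy as jnp

test_w = model.initialize(jax.random.PRNGKey(0))
dummy_inputs = jnp.zeros((batch_size, context), dtype=jnp.int32)
print(model(dummy_inputs, test_w).shape)  # expect (batch_size, context, vocab_size) = (12, 64, 65)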

Loss function and training

To train our transformer we’ll use cross entropy loss, which we can compute by decomposing the softmax:

\[-\log(\text{target probability}) = -\log(\text{softmax}(\text{logits})_\text{target}) = -\text{logit}_\text{target} + \text{log\,sum\,exp}(\text{logits})\]
[8]:
import jax
import jax.numpy as jnp

def cross_entropy_loss(w, inputs, targets):
    # We use the logsumexp trick for stable cross entropy
    logits = model(inputs, w)  # shape is [batch, seq_len, vocab_size]
    batch_indices = jnp.arange(logits.shape[0])[:, None]  # shape is [batch, 1]
    seq_indices = jnp.arange(logits.shape[1])[None, :]    # shape is [1, seq_len]
    # This indexing selects out logits[b, s, targets[b, s]], which is the target logit
    losses = -logits[batch_indices, seq_indices, targets] + jax.nn.logsumexp(logits, axis=-1)  # shape is [batch, seq_len]
    return losses.mean()

loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))
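
As an optional sanity check (a sketch, not part of the original notebook), the fancy indexing above should agree with a direct log-softmax computation of the target log-probabilities:

def cross_entropy_loss_reference(w, inputs, targets):
    logits = model(inputs, w)
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # [batch, seq_len, vocab_size]
    # pick out log_probs[b, s, targets[b, s]] at every position
    target_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -target_log_probs.mean()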

And we’re ready to train!

[9]:
key = jax.random.PRNGKey(0)
w = model.initialize(key)

step = 0
momentum = [0 * weight for weight in w]
lr_schedule = lambda step: lr * (steps - step) / steps
for inputs, targets in train_loader:
    loss, grad_w = loss_and_grad(w, inputs, targets)
    momentum = [beta * m + (1 - beta) * g_w for m, g_w in zip(momentum, grad_w)]
    d_w = model.dualize(momentum)
    w = [weight - lr_schedule(step) * d_weight for weight, d_weight in zip(w, d_w)]

    if step % log_interval == 0:
        print(f"Step {step}: loss {loss}")

    if step % val_interval == 0:
        val_losses = []
        for val_inputs, val_targets in val_loader:
            loss, _ = loss_and_grad(w, val_inputs, val_targets)
            val_losses.append(loss)
            if len(val_losses) >= val_iters:
                break
        print(f"--> val loss {sum(val_losses)/len(val_losses)}")

    step += 1

    if step >= steps:
        break
Step 0: loss 4.226325988769531
--> val loss 4.179544448852539
Step 10: loss 3.8738746643066406
Step 20: loss 3.3448646068573
Step 30: loss 2.805002212524414
Step 40: loss 2.68573260307312
Step 50: loss 2.6098480224609375
Step 60: loss 2.407468557357788
Step 70: loss 2.418379783630371
Step 80: loss 2.359757423400879
Step 90: loss 2.2685279846191406
Step 100: loss 2.314124584197998
--> val loss 2.541980743408203
Step 110: loss 2.283424139022827
Step 120: loss 2.2063167095184326
Step 130: loss 2.1598031520843506
Step 140: loss 2.252727508544922
Step 150: loss 2.124152660369873
Step 160: loss 2.23785662651062
Step 170: loss 2.2059123516082764
Step 180: loss 2.102996587753296
Step 190: loss 2.132392168045044
Step 200: loss 2.130244255065918
--> val loss 2.359212636947632
Step 210: loss 2.0895276069641113
Step 220: loss 2.1278815269470215
Step 230: loss 1.9647449254989624
Step 240: loss 2.1118733882904053
Step 250: loss 1.9459623098373413
Step 260: loss 2.118051290512085
Step 270: loss 2.0605385303497314
Step 280: loss 2.0378551483154297
Step 290: loss 2.0237479209899902
Step 300: loss 1.982785940170288
--> val loss 2.2887392044067383
Step 310: loss 2.073058605194092
Step 320: loss 2.082066535949707
Step 330: loss 2.130162239074707
Step 340: loss 2.092909336090088
Step 350: loss 1.9229984283447266
Step 360: loss 1.9037134647369385
Step 370: loss 2.0083131790161133
Step 380: loss 2.0236263275146484
Step 390: loss 2.0116419792175293
Step 400: loss 2.091407299041748
--> val loss 2.2199790477752686
Step 410: loss 2.0855846405029297
Step 420: loss 1.8506882190704346
Step 430: loss 1.9745848178863525
Step 440: loss 1.9135173559188843
Step 450: loss 2.0486648082733154
Step 460: loss 1.983982801437378
Step 470: loss 1.9958977699279785
Step 480: loss 1.9868993759155273
Step 490: loss 2.009216785430908
Step 500: loss 2.073169231414795
--> val loss 2.141632556915283
Step 510: loss 2.0603322982788086
Step 520: loss 2.0025858879089355
Step 530: loss 1.9482192993164062
Step 540: loss 1.9092429876327515
Step 550: loss 2.109374761581421
Step 560: loss 1.9060167074203491
Step 570: loss 1.9423940181732178
Step 580: loss 1.9405231475830078
Step 590: loss 1.9132475852966309
Step 600: loss 2.0125274658203125
--> val loss 2.2273831367492676
Step 610: loss 2.0854687690734863
Step 620: loss 1.9796791076660156
Step 630: loss 1.982351303100586
Step 640: loss 2.044363021850586
Step 650: loss 2.030698299407959
Step 660: loss 2.0731544494628906
Step 670: loss 1.9660027027130127
Step 680: loss 1.933128833770752
Step 690: loss 1.8852118253707886
Step 700: loss 1.8401598930358887
--> val loss 2.0958476066589355
Step 710: loss 1.9790323972702026
Step 720: loss 2.0329394340515137
Step 730: loss 1.929424524307251
Step 740: loss 1.950282335281372
Step 750: loss 1.938680648803711
Step 760: loss 1.9717748165130615
Step 770: loss 1.8411779403686523
Step 780: loss 2.085500717163086
Step 790: loss 1.8778104782104492
Step 800: loss 1.9712986946105957
--> val loss 2.1469686031341553
Step 810: loss 1.949462652206421
Step 820: loss 1.9898126125335693
Step 830: loss 1.9045312404632568
Step 840: loss 1.9053363800048828
Step 850: loss 1.8944416046142578
Step 860: loss 1.8389015197753906
Step 870: loss 1.9189136028289795
Step 880: loss 2.0141639709472656
Step 890: loss 1.9987534284591675
Step 900: loss 1.947631597518921
--> val loss 2.1903281211853027
Step 910: loss 2.031083106994629
Step 920: loss 1.988853931427002
Step 930: loss 2.0356318950653076
Step 940: loss 1.8823192119598389
Step 950: loss 2.0429515838623047
Step 960: loss 2.021817684173584
Step 970: loss 2.003168821334839
Step 980: loss 2.0105528831481934
Step 990: loss 2.014195680618286
Step 1000: loss 1.9518741369247437
--> val loss 2.0813283920288086
Step 1010: loss 2.016996383666992
Step 1020: loss 2.04374098777771
Step 1030: loss 1.8839387893676758
Step 1040: loss 1.96620512008667
Step 1050: loss 2.0463950634002686
Step 1060: loss 1.9169645309448242
Step 1070: loss 2.038651943206787
Step 1080: loss 2.0474071502685547
Step 1090: loss 1.9452462196350098
Step 1100: loss 1.8884999752044678
--> val loss 2.1541106700897217
Step 1110: loss 1.9775495529174805
Step 1120: loss 1.96068274974823
Step 1130: loss 1.8553755283355713
Step 1140: loss 1.9422013759613037
Step 1150: loss 2.0833449363708496
Step 1160: loss 1.840619444847107
Step 1170: loss 2.032219409942627
Step 1180: loss 1.9345749616622925
Step 1190: loss 1.934565544128418
Step 1200: loss 1.9528722763061523
--> val loss 2.1688506603240967
Step 1210: loss 1.8676490783691406
Step 1220: loss 1.9311145544052124
Step 1230: loss 1.9905321598052979
Step 1240: loss 1.8773740530014038
Step 1250: loss 1.9832658767700195
Step 1260: loss 1.8256521224975586
Step 1270: loss 2.037313461303711
Step 1280: loss 1.9440114498138428
Step 1290: loss 1.9472723007202148
Step 1300: loss 1.862718105316162
--> val loss 2.0632894039154053
Step 1310: loss 1.944453239440918
Step 1320: loss 1.869157075881958
Step 1330: loss 1.9843480587005615
Step 1340: loss 1.9083728790283203
Step 1350: loss 1.920233130455017
Step 1360: loss 1.7926225662231445
Step 1370: loss 1.8765363693237305
Step 1380: loss 1.9374698400497437
Step 1390: loss 1.9032771587371826
Step 1400: loss 1.8976068496704102
--> val loss 2.0361690521240234
Step 1410: loss 1.8799960613250732
Step 1420: loss 1.9112414121627808
Step 1430: loss 1.8797309398651123
Step 1440: loss 1.9040837287902832
Step 1450: loss 1.8828296661376953
Step 1460: loss 1.83419930934906
Step 1470: loss 1.8327134847640991
Step 1480: loss 1.857541799545288
Step 1490: loss 1.8209788799285889
Step 1500: loss 1.780470371246338
--> val loss 2.0466208457946777
Step 1510: loss 1.8544996976852417
Step 1520: loss 1.8710064888000488
Step 1530: loss 1.8195044994354248
Step 1540: loss 1.874974250793457
Step 1550: loss 1.7101812362670898
Step 1560: loss 1.8439801931381226
Step 1570: loss 1.967679500579834
Step 1580: loss 1.888682246208191
Step 1590: loss 1.6926288604736328
Step 1600: loss 1.875901222229004
--> val loss 2.044935941696167
Step 1610: loss 1.8210939168930054
Step 1620: loss 1.7439773082733154
Step 1630: loss 1.7956527471542358
Step 1640: loss 1.792572021484375
Step 1650: loss 1.7985519170761108
Step 1660: loss 1.8520288467407227
Step 1670: loss 1.680544137954712
Step 1680: loss 1.7917392253875732
Step 1690: loss 1.8400462865829468
Step 1700: loss 1.6793416738510132
--> val loss 1.995697021484375
Step 1710: loss 1.7414367198944092
Step 1720: loss 1.8606326580047607
Step 1730: loss 1.7578084468841553
Step 1740: loss 1.6292760372161865
Step 1750: loss 1.7017428874969482
Step 1760: loss 1.8407533168792725
Step 1770: loss 1.7789411544799805
Step 1780: loss 1.802499532699585
Step 1790: loss 1.7586851119995117
Step 1800: loss 1.7281568050384521
--> val loss 1.9875770807266235
Step 1810: loss 1.7767337560653687
Step 1820: loss 1.7158925533294678
Step 1830: loss 1.7596324682235718
Step 1840: loss 1.7826766967773438
Step 1850: loss 1.7769875526428223
Step 1860: loss 1.6953961849212646
Step 1870: loss 1.7714271545410156
Step 1880: loss 1.6994340419769287
Step 1890: loss 1.7252253293991089
Step 1900: loss 1.566367506980896
--> val loss 1.9310436248779297
Step 1910: loss 1.7057380676269531
Step 1920: loss 1.7441104650497437
Step 1930: loss 1.7951183319091797
Step 1940: loss 1.8611491918563843
Step 1950: loss 1.787139654159546
Step 1960: loss 1.788725733757019
Step 1970: loss 1.7919573783874512
Step 1980: loss 1.706597089767456
Step 1990: loss 1.771501898765564
Step 2000: loss 1.7121562957763672
--> val loss 1.8968441486358643

Though this be madness, yet there is method in’t

And indeed, let us look at how our wee model stacks up to the master.

[10]:
def generate_text(prompt, max_tokens=100, temperature=0.5, seed=0):
    key = jax.random.PRNGKey(seed)
    tokens = jnp.array(encode(prompt))
    for _ in range(max_tokens):
        logits = model(jnp.expand_dims(tokens, 0), w)
        next_token_logits = logits[0, -1] / temperature

        # Sample from our model's token distribution
        key, subkey = jax.random.split(key)
        next_token = jax.random.categorical(subkey, next_token_logits)
        tokens = jnp.append(tokens, next_token)

    return decode(tokens)

for seed in range(3):
    print(f"Sample {seed}:\n\n{generate_text('If', max_tokens=100, seed=seed)}")
    print("-" * 80)
Sample 0:

If where his elperiend and is here in think the comfore be pray virtue deather I the grouth a pears my
--------------------------------------------------------------------------------
Sample 1:

If as the conture the weet to the man's death the greeen he with thought rame the prosates he palousen
--------------------------------------------------------------------------------
Sample 2:

If him the be not me were and let for the earth the forth,
That the his a wort of you the fearshould a
--------------------------------------------------------------------------------