Basic Transformer Architecture Notes

Intro

Here are some notes on the basic transformer architecture, written for my own learning and understanding. They are best used as a secondary resource rather than a first stop; there are many excellent resources out there that cover the same ideas in more depth.

Tokenization and Input Embeddings

In code comments I will use the following shorthand symbols:

from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size  # will also denote as V
seq_length = 16  # will also denote sequence length as T i.e. "time dimension"
embed_dim = 64  # will also denote as C i.e. "channel dimension"
num_heads = 8
head_dim = embed_dim // num_heads  # will also denote as H
Using device: cuda

texts = [
    "I love summer",
    "I love tacos",
]
inputs = tokenizer(
    texts,
    return_tensors="pt",
    padding="max_length",
    max_length=seq_length,
    truncation=True,
).input_ids
inputs
tensor([[  101,  1045,  2293,  2621,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  1045,  2293, 11937, 13186,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0]])
print(inputs.shape)  # (B, T)
print(vocab_size)
torch.Size([2, 16])
30522
for row in inputs:
    print(tokenizer.convert_ids_to_tokens(row))
['[CLS]', 'i', 'love', 'summer', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['[CLS]', 'i', 'love', 'ta', '##cos', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

Now that the text is tokenized we can look up the token embeddings. Here is the look-up token embedding table:

token_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
token_emb  # (V, C)
Embedding(30522, 64)

Get the token embeddings for the batch of inputs:

token_embeddings = token_emb(inputs)
token_embeddings.shape  # (B, T, C)
torch.Size([2, 16, 64])

There are various methods for positional embeddings, but here is a very simple approach: a learned embedding table with one vector per position.

positional_emb = nn.Embedding(num_embeddings=seq_length, embedding_dim=embed_dim)  # (T, C)
positional_embeddings = positional_emb(torch.arange(start=0, end=seq_length, step=1))
positional_embeddings.shape  # (T, C)
torch.Size([16, 64])
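
As an aside, the original "Attention Is All You Need" paper uses fixed sinusoidal encodings rather than a learned table. Here is a minimal sketch for comparison, reusing seq_length and embed_dim from above; it is not used in the rest of these notes.

# fixed sinusoidal positional encodings, shown only for comparison with the learned table above
pos = torch.arange(seq_length).unsqueeze(1)  # (T, 1)
two_i = torch.arange(0, embed_dim, 2)  # even channel indices 0, 2, ..., C-2
angles = pos / (10000 ** (two_i / embed_dim))  # (T, C/2)
sinusoidal_pe = torch.zeros(seq_length, embed_dim)
sinusoidal_pe[:, 0::2] = torch.sin(angles)  # sine on even channels
sinusoidal_pe[:, 1::2] = torch.cos(angles)  # cosine on odd channels
sinusoidal_pe.shape  # (T, C)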

Using broadcasting, we can add the two embeddings (token and positional) to get the final input embeddings.

token_embeddings.shape  # (B, T, C)
torch.Size([2, 16, 64])
embeddings = token_embeddings + positional_embeddings
embeddings.shape
torch.Size([2, 16, 64])

Self Attention

We begin with our input embeddings:

# our embeddings input: (B, T, C)
embeddings.shape
torch.Size([2, 16, 64])
query = nn.Linear(in_features=embed_dim, out_features=head_dim, bias=False)
key = nn.Linear(in_features=embed_dim, out_features=head_dim, bias=False)
value = nn.Linear(in_features=embed_dim, out_features=head_dim, bias=False)

# projections of the original embeddings
q = query(embeddings)  # (B, T, head_dim)
k = key(embeddings)  # (B, T, head_dim)
v = value(embeddings)  # (B, T, head_dim)
q.shape, k.shape, v.shape
(torch.Size([2, 16, 8]), torch.Size([2, 16, 8]), torch.Size([2, 16, 8]))
w = (q @ k.transpose(-2, -1)) / sqrt(head_dim)  # (B, T, T) gives the scores between all the token embeddings within each batch
# optional causal mask: block attention to future positions (used in decoder-style models)
tril = torch.tril(torch.ones(seq_length, seq_length))
w = w.masked_fill(tril == 0, float("-inf"))
# normalize weights
w = F.softmax(w, dim=-1)  # (B, T, T)
w.shape
torch.Size([2, 16, 16])
# weighted average (linear combination) of the projected input embeddings
out = w @ v
out.shape
torch.Size([2, 16, 8])
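
In equation form, the computation above is the standard scaled dot-product attention:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the head dimension (head_dim above), and the optional causal mask sets the scores of future positions to -inf before the softmax so that they receive zero weight.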

Multi Head Attention
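
Multi-head attention simply runs several independent attention heads like the one above in parallel, concatenates their outputs along the channel dimension, and mixes them with a final linear projection. Below is a minimal sketch continuing the single-head example (reusing embeddings, num_heads, embed_dim, and head_dim from above, and omitting the causal mask for brevity); the proper nn.Module version appears in the Putting it all Together section.

heads_out = []
for _ in range(num_heads):
    # each head gets its own query/key/value projections
    q_proj = nn.Linear(embed_dim, head_dim, bias=False)
    k_proj = nn.Linear(embed_dim, head_dim, bias=False)
    v_proj = nn.Linear(embed_dim, head_dim, bias=False)
    q, k, v = q_proj(embeddings), k_proj(embeddings), v_proj(embeddings)  # each (B, T, H)
    w = F.softmax(q @ k.transpose(-2, -1) / sqrt(head_dim), dim=-1)  # (B, T, T)
    heads_out.append(w @ v)  # (B, T, H)

# concatenate the heads back to the full channel dimension and mix them with a linear projection
multi_head_out = nn.Linear(embed_dim, embed_dim)(torch.cat(heads_out, dim=-1))
multi_head_out.shape  # (B, T, C) since num_heads * head_dim == embed_dim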

Feed forward layer (FFN)
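
Each transformer block also contains a position-wise feed-forward network: a small two-layer MLP applied independently to every token position. It expands the channel dimension (by a factor of 4 here), applies a non-linearity, and projects back down. A minimal sketch, again reusing embed_dim and embeddings from above (the proper nn.Module version appears in the next section):

ffn = nn.Sequential(
    nn.Linear(embed_dim, 4 * embed_dim),  # expand: (B, T, C) -> (B, T, 4C)
    nn.GELU(),
    nn.Linear(4 * embed_dim, embed_dim),  # project back down: (B, T, 4C) -> (B, T, C)
)
ffn(embeddings).shape  # (B, T, C)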

The only things we have not mentioned yet are Layer Normalization and skip (residual) connections. These are standard tricks for stabilizing the training of deep networks. It will become clearer how they are used in the next section, when we put everything together in code.
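
As a quick preview, each sub-layer (multi-head attention or FFN) is wrapped in a pre-norm residual pattern, x = x + sublayer(layer_norm(x)). Here is a tiny sketch with a stand-in linear sub-layer, just to show that the shape is preserved:

layer_norm = nn.LayerNorm(embed_dim)
sublayer = nn.Linear(embed_dim, embed_dim)  # stand-in for attention or the FFN
out = embeddings + sublayer(layer_norm(embeddings))  # skip connection around the normalized sub-layer
out.shape  # (B, T, C) -- unchanged, so blocks can be stacked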

Putting it all Together

from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


class Config:
    vocab_size = tokenizer.vocab_size
    seq_length = 128  # will also denote as T i.e. "time dimension"
    batch_size = 256  # will also denote as B
    embed_dim = 64  # will also denote as C i.e. "channel dimension"
    num_heads = 4
    head_dim = embed_dim // num_heads  #  will also denote as H
    dropout_prob = 0.0
    num_transformer_layers = 4


class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embd = nn.Embedding(config.vocab_size, config.embed_dim)
        self.pos_embd = nn.Embedding(config.seq_length, config.embed_dim)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.layer_norm = nn.LayerNorm(config.embed_dim)

    def forward(self, x):
        # x is B,T --> the tensor of token input_ids
        seq_length = x.size(-1)
        token_embeddings = self.token_embd(x)  # (B, T, C)
        positional_embeddings = self.pos_embd(torch.arange(start=0, end=seq_length, step=1, device=x.device))  # (T, C)
        x = token_embeddings + positional_embeddings  # (B, T, C)
        x = self.layer_norm(x)  # (B, T, C)
        x = self.dropout(x)  # (B, T, C)
        return x


class AttentionHead(nn.Module):
    def __init__(self, config, mask=True):
        super().__init__()
        self.mask = mask
        self.query = nn.Linear(config.embed_dim, config.head_dim, bias=False)
        self.key = nn.Linear(config.embed_dim, config.head_dim, bias=False)
        self.value = nn.Linear(config.embed_dim, config.head_dim, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(config.seq_length, config.seq_length)))
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x):
        # x is (B, T, C)
        b, t, c = x.shape
        q = self.query(x)  # (B, T, H)
        k = self.key(x)  # (B, T, H)
        v = self.value(x)  # (B, T, H)

        dim_k = k.shape[-1]  # i.e. head dimension
        w = q @ k.transpose(-2, -1) / sqrt(dim_k)  # (B, T, T)
        if self.mask:
            w = w.masked_fill(self.tril[:t, :t] == 0, float("-inf"))  # (B, T, T)
        w = F.softmax(w, dim=-1)  # (B, T, T)
        w = self.dropout(w)  # good for regularization
        out = w @ v  # (B, T, H)
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, config, mask=True):
        super().__init__()
        self.attention_heads = nn.ModuleList([AttentionHead(config, mask) for _ in range(config.num_heads)])
        self.linear_proj = nn.Linear(config.embed_dim, config.embed_dim)

    def forward(self, x):
        # each input tensor x has shape (B, T, C)
        # each attention head, head(x) is of shape (B, T, H)
        # concat these along the last dimension to get (B, T, C)
        x = torch.concat([head(x) for head in self.attention_heads], dim=-1)
        return self.linear_proj(x)  # (B, T, C)


class FeedForwardNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer1 = nn.Linear(config.embed_dim, 4 * config.embed_dim)
        self.layer2 = nn.Linear(4 * config.embed_dim, config.embed_dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x):
        # x is (B, T, C)
        x = self.layer1(x)  # (B, T, 4C)
        x = self.gelu(x)  # (B, T, 4C)
        x = self.layer2(x)  # (B, T, C)
        x = self.dropout(x)  # (B, T, C)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config, mask=True):
        super().__init__()
        self.mha = MultiHeadAttention(config, mask)
        self.ffn = FeedForwardNetwork(config)
        self.layer_norm_1 = nn.LayerNorm(config.embed_dim)
        self.layer_norm_2 = nn.LayerNorm(config.embed_dim)

    def forward(self, x):
        # x is (B, T, C)
        x = x + self.mha(self.layer_norm_1(x))  # (B, T, C)
        x = x + self.ffn(self.layer_norm_2(x))  # (B, T, C)
        return x


class Transformer(nn.Module):
    def __init__(self, config, mask=True):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.transformer_layers = nn.ModuleList([TransformerBlock(config, mask) for _ in range(config.num_transformer_layers)])

    def forward(self, x):
        # x is shape (B, T). It is the output from a tokenizer
        x = self.embeddings(x)  # (B, T, C)
        for layer in self.transformer_layers:
            x = layer(x)  # (B, T, C)
        return x  # (B, T, C)
Transformer(Config).to(device)(inputs.to(device)).shape
torch.Size([2, 16, 64])

Training Decoder For Next Token Prediction
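
To set up next-token prediction, each training example pairs a window of Config.seq_length tokens with the same window shifted one position to the left, so the label at position t is the token at position t + 1. The tokenize function below builds these (input_ids, labels) pairs by tokenizing seq_length + 1 tokens and dropping the last or first token respectively.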

dataset = load_dataset("roneneldan/TinyStories")["train"]
dataset = dataset.select(range(500000))  # select fewer data points to speed up training
def tokenize(element):
    # Increase max_length by 1 to get the next token
    outputs = tokenizer(
        element["text"],
        truncation=True,
        max_length=Config.seq_length + 1,
        padding="max_length",
        return_overflowing_tokens=True,
        return_length=True,
        add_special_tokens=False,
    )
    input_batch = []
    target_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == Config.seq_length + 1:
            input_batch.append(input_ids[:-1])  # Exclude the last token for input
            target_batch.append(input_ids[1:])  # Exclude the first token for target
    return {"input_ids": input_batch, "labels": target_batch}


tokenized_datasets = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names, num_proc=8)
print(tokenizer.convert_ids_to_tokens(tokenized_datasets[0]["input_ids"]))
print(tokenizer.convert_ids_to_tokens(tokenized_datasets[0]["labels"]))
['one', 'day', ',', 'a', 'little', 'girl', 'named', 'lily', 'found', 'a', 'needle', 'in', 'her', 'room', '.', 'she', 'knew', 'it', 'was', 'difficult', 'to', 'play', 'with', 'it', 'because', 'it', 'was', 'sharp', '.', 'lily', 'wanted', 'to', 'share', 'the', 'needle', 'with', 'her', 'mom', ',', 'so', 'she', 'could', 'se', '##w', 'a', 'button', 'on', 'her', 'shirt', '.', 'lily', 'went', 'to', 'her', 'mom', 'and', 'said', ',', '"', 'mom', ',', 'i', 'found', 'this', 'needle', '.', 'can', 'you', 'share', 'it', 'with', 'me', 'and', 'se', '##w', 'my', 'shirt', '?', '"', 'her', 'mom', 'smiled', 'and', 'said', ',', '"', 'yes', ',', 'lily', ',', 'we', 'can', 'share', 'the', 'needle', 'and', 'fix', 'your', 'shirt', '.', '"', 'together', ',', 'they', 'shared', 'the', 'needle', 'and', 'se', '##wed', 'the', 'button', 'on', 'lily', "'", 's', 'shirt', '.', 'it', 'was', 'not', 'difficult', 'for', 'them', 'because', 'they', 'were', 'sharing']
['day', ',', 'a', 'little', 'girl', 'named', 'lily', 'found', 'a', 'needle', 'in', 'her', 'room', '.', 'she', 'knew', 'it', 'was', 'difficult', 'to', 'play', 'with', 'it', 'because', 'it', 'was', 'sharp', '.', 'lily', 'wanted', 'to', 'share', 'the', 'needle', 'with', 'her', 'mom', ',', 'so', 'she', 'could', 'se', '##w', 'a', 'button', 'on', 'her', 'shirt', '.', 'lily', 'went', 'to', 'her', 'mom', 'and', 'said', ',', '"', 'mom', ',', 'i', 'found', 'this', 'needle', '.', 'can', 'you', 'share', 'it', 'with', 'me', 'and', 'se', '##w', 'my', 'shirt', '?', '"', 'her', 'mom', 'smiled', 'and', 'said', ',', '"', 'yes', ',', 'lily', ',', 'we', 'can', 'share', 'the', 'needle', 'and', 'fix', 'your', 'shirt', '.', '"', 'together', ',', 'they', 'shared', 'the', 'needle', 'and', 'se', '##wed', 'the', 'button', 'on', 'lily', "'", 's', 'shirt', '.', 'it', 'was', 'not', 'difficult', 'for', 'them', 'because', 'they', 'were', 'sharing', 'and']
class LanguageModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = Transformer(config, mask=True)
        self.classifier = nn.Linear(config.embed_dim, tokenizer.vocab_size)

    def forward(self, x):
        # x is (B, T) the token ids
        x = self.transformer(x)  # (B, T, C)
        logits = self.classifier(x)  # (B, T, V)
        return logits
model = LanguageModel(Config).to(device)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
4.144826 M parameters
@torch.no_grad()  # inference only; no need to track gradients
def generate_text(prompt, max_tokens=100):
    inputs = tokenizer(
        [prompt],
        truncation=True,
        max_length=Config.seq_length,
        add_special_tokens=False,
    )
    inputs = torch.tensor(inputs.input_ids).to(device)

    for i in range(max_tokens):
        logits = model(inputs)  # (B, T, V)
        # convert logits to probabilities and only consider the probabilities for the last token in the sequence i.e. predict next token
        probs = logits[:, -1, :].softmax(dim=-1)
        # sample a token from the distribution over the vocabulary
        idx_next = torch.multinomial(probs, num_samples=1)
        inputs = torch.cat([inputs, idx_next], dim=1)
    return tokenizer.decode(inputs[0])

Since we have not trained the model yet, this output should be completely random garbage tokens.

print(generate_text("Once upon"))
once upon align tiltlm nicole speech quintet outsideials 1833 1785 asteroid exit jim caroline 19th 分 tomatoआ mt joanne ball busted hear hears neighbourhoods twitterouringbis maoma 貝 oven williams [unused646] presidential [unused618] [unused455]版tish gavin accountability stanford materials chung avoids unstable hyde culinary گ catalonia versatile gradient gross geography porn justice contributes deposition robotics 00pm showcased current laying b aixroudzko rooney abrahamhedron sideways postseason grossed conviction overheard crowley said warehouses heights times arising 80 reeve deptrned noelle fingered pleistocene pushed rock buddhist [unused650] brunette nailed upstream [unused86] ufc bolts鈴 grounds

for epoch in range(1):
    train_loss = []
    loop = tqdm(range(0, len(tokenized_datasets), Config.batch_size))
    for i in loop:
        x = torch.tensor(tokenized_datasets[i : i + Config.batch_size]["input_ids"]).to(device)  # (B, T)
        target = torch.tensor(tokenized_datasets[i : i + Config.batch_size]["labels"]).to(device)  # (B, T)
        logits = model(x)  # (B, T, V)

        b, t, v = logits.size()
        logits = logits.view(b * t, v)  # (B*T, V)
        target = target.view(b * t)  # B*T
        loss = F.cross_entropy(logits, target)
        train_loss.append(loss.item())
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            avg_loss = np.mean(train_loss)
            train_loss = []
            loop.set_description(f"Epoch {epoch}, Avg Loss: {avg_loss:.4f}")
Epoch 0, Avg Loss: 2.1089: 100%|██████████| 4357/4357 [13:41<00:00,  5.30it/s]
print(generate_text("Once upon"))
once upon a time, there was a little girl named maggie who loved watching movies in her neighborhood. today, ellie decided to do it was a sunny day. her mommy asked, " what does it means you fill the start? " her mommy the thief played with her very look at it. " the cop replied, " i want to get dessert together! " sarah watched together and took some of her pencils and sugar. mommy took a look so carefully together. they played together until it had something shiny

Not quite GPT-4 performance, lol!