The Evolution of Transformers: From BERT to Models with Billions of Parameters

Transformers have become the dominant model architecture across natural language processing, computer vision, and other domains in recent years. In this blog post, I'll dive into some of the major innovations that have allowed Transformers like BERT, GPT-3, and others to scale to unprecedented sizes and capabilities.

The Rise of Attention Mechanisms

The Transformer was introduced in the 2017 paper "Attention Is All You Need". Its key innovation was to drop recurrence and convolutions entirely and rely on attention alone to model dependencies between tokens.
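
The core operation is scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V. Below is a minimal single-head sketch with no masking and no learned projections. The function name and the tensor sizes are illustrative choices, not taken from the paper.

```python
import torch

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V.

    Q: (query_len, d_k), K: (key_len, d_k), V: (key_len, d_v)
    """
    d_k = Q.shape[-1]
    # Similarity of every query with every key: (query_len, key_len)
    scores = Q @ K.transpose(-2, -1) / d_k**0.5
    # Each row becomes a probability distribution over the keys
    weights = torch.softmax(scores, dim=-1)
    # Each output row is a weighted average of the value vectors
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(4, 8)   # 4 query tokens, d_k = 8
K = torch.randn(6, 8)   # 6 key tokens
V = torch.randn(6, 16)  # d_v = 16
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([4, 16])
```

Dividing by sqrt(d_k) keeps the dot products from growing with the vector size. Without it, the softmax would saturate and give vanishing gradients.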

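The Transformer uses multi-headed self-attention, in which every token attends to every token in the sequence, including itself. Several attention heads run side by side, each with its own learned projections of the queries, keys, and values. Their outputs are concatenated and projected back to the model dimension. This lets the model capture different kinds of contextual relationships at once, and because no position waits on the previous one, the whole sequence is processed in parallel: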

import torch
import torch.nn as nn

class MultiHeadedSelfAttention(nn.Module):
    def __init__(self, hid_dim, n_heads):
        super().__init__()
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        # Learned projections for queries, keys, values, and the output
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        # Linear projections
        Q = self.fc_q(query)  # (batch_size, query_len, hid_dim)
        K = self.fc_k(key)    # (batch_size, key_len, hid_dim)
        V = self.fc_v(value)  # (batch_size, value_len, hid_dim)

        # Split into multiple heads: (batch_size, n_heads, seq_len, head_dim)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product scores: (batch_size, n_heads, query_len, key_len)
        attn = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.head_dim**0.5

        # Block positions where mask == 0 (e.g. padding or future tokens)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = torch.softmax(attn, dim=-1)

        # Weighted sum of values: (batch_size, n_heads, query_len, head_dim)
        x = torch.matmul(attn, V)

        # Concatenate heads: (batch_size, query_len, hid_dim)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)

        x = self.fc_o(x)

        return x

This mechanism became the foundation on which BERT, GPT, and their successors were built.

Bigger and Bigger Pretrained Models

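Once attention had proven its value, a race to scale up Transformers began. In 2018, Google's BERT advanced the state of the art across many NLP tasks with a bidirectional pretraining objective (masked language modeling), while OpenAI's GPT was trained autoregressively to predict the next token. In 2019, GPT-2 showed that such a model could generate coherent paragraphs of text, and XLNet improved on BERT with a permutation-based pretraining objective that still captures bidirectional context.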
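
The difference between GPT-style next-token prediction and BERT-style bidirectional pretraining shows up directly in the attention mask. Here is a minimal sketch, assuming the convention that 1 means "may attend" and 0 means "blocked", the same convention as the `mask == 0` check in the attention code above. The uniform scores are made up for illustration.

```python
import torch

seq_len = 5

# BERT-style (bidirectional): every token may attend to every token
bidirectional_mask = torch.ones(seq_len, seq_len)

# GPT-style (causal): token i may attend only to positions 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

# With uniform scores, each row spreads its attention evenly over the allowed keys
scores = torch.zeros(seq_len, seq_len)
bi_weights = torch.softmax(scores.masked_fill(bidirectional_mask == 0, float("-inf")), dim=-1)
causal_weights = torch.softmax(scores.masked_fill(causal_mask == 0, float("-inf")), dim=-1)

# Row 2 of the causal weights: equal weight on positions 0, 1, 2 and zero on 3, 4
print(causal_weights[2])
```

The causal mask is what lets a GPT-style model be trained on every position of a sequence at once without seeing the tokens it is asked to predict.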

These models were still relatively small, ranging from about a hundred million to 1.5 billion parameters. In 2020, GPT-3 demonstrated the benefits of massive scale. With 175 billion parameters, it could perform many tasks from just a few examples given in its prompt (few-shot learning), without task-specific fine-tuning.
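
You can roughly reproduce these parameter counts from a model's shape. Each layer has about 4·d² parameters in the Q, K, V, and output projections, plus 8·d² in a feed-forward block of width 4·d. That gives about 12·d² per layer, plus the token and position embeddings. The sketch below ignores biases and LayerNorms. The configurations are the published BERT-base and GPT-3 175B shapes.

```python
def approx_transformer_params(n_layers, d_model, vocab_size, n_positions=0):
    """Rough parameter count: 12 * d_model^2 per layer plus embeddings.

    Biases, LayerNorm weights, and other small terms are ignored.
    """
    per_layer = 12 * d_model ** 2  # 4*d^2 attention + 8*d^2 feed-forward
    embeddings = (vocab_size + n_positions) * d_model
    return n_layers * per_layer + embeddings

# BERT-base: 12 layers, d_model 768, 30,522-token vocabulary, 512 positions
bert_base = approx_transformer_params(12, 768, 30522, n_positions=512)
# GPT-3 175B: 96 layers, d_model 12,288, 50,257-token vocabulary, 2,048 positions
gpt3 = approx_transformer_params(96, 12288, 50257, n_positions=2048)

print(f"BERT-base ~ {bert_base / 1e6:.0f}M")  # BERT-base ~ 109M
print(f"GPT-3     ~ {gpt3 / 1e9:.0f}B")       # GPT-3     ~ 175B
```

Both estimates land close to the reported figures of about 110M and 175B. This shows that the count is dominated by the d² terms, which is why widening the model grows it so quickly.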

There was one problem, though. Standard dot-product self-attention scores every token against every other token, so its compute and memory costs grow quadratically with sequence length: doubling the context quadruples the size of the attention matrix. The next phase of innovations aimed to address this limitation.
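
A quick back-of-the-envelope calculation makes the quadratic growth concrete. This sketch counts only the memory for one layer's attention score matrix. It assumes 16 heads, a batch of one, and 2-byte (fp16) elements, all illustrative choices.

```python
def attention_scores_bytes(seq_len, n_heads, batch_size=1, bytes_per_elem=2):
    """Bytes for one layer's attention scores, shape
    (batch_size, n_heads, seq_len, seq_len); 2 bytes per element = fp16."""
    return batch_size * n_heads * seq_len * seq_len * bytes_per_elem

for n in (1024, 2048, 4096):
    mib = attention_scores_bytes(n, n_heads=16) / 2**20
    print(f"{n:>5} tokens: {mib:.0f} MiB per layer")
```

Going from 1,024 to 4,096 tokens takes the scores from 32 MiB to 512 MiB per layer. That is a 16x increase for a 4x longer input, before counting the activations stored for backpropagation.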

Making Attention More Efficient

Several methods have been introduced to make attention practical for long sequences in large Transformers:

Sparse attention restricts each query to a subset of the keys instead of the full sequence, for example a local window of neighboring tokens plus a small set of global tokens. The softmax is still computed, but only over the allowed positions:

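import torch

# Sparse attention (illustrative): each query attends to a local window of
# neighbors plus a few global tokens. A dense boolean mask is used for clarity;
# efficient implementations never build the full seq_len x seq_len matrix.
batch_size, seq_len, hid_dim = 2, 1024, 64
nb_global_tokens = 64
window = 128  # each token sees up to `window` neighbors on each side

Q = torch.randn(batch_size, seq_len, hid_dim)  # queries
K = torch.randn(batch_size, seq_len, hid_dim)  # keys
V = torch.randn(batch_size, seq_len, hid_dim)  # values

# Local attention: banded mask |i - j| <= window
idx = torch.arange(seq_len)
local_mask = (idx[:, None] - idx[None, :]).abs() <= window

# Global attention: the first nb_global_tokens attend to, and are seen by, all tokens
global_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
global_mask[:, :nb_global_tokens] = True
global_mask[:nb_global_tokens, :] = True
mask = local_mask | global_mask  # (seq_len, seq_len), broadcast over the batch

scores = torch.matmul(Q, K.transpose(1, 2)) / hid_dim**0.5
scores = scores.masked_fill(~mask, float("-inf"))
A = torch.softmax(scores, dim=-1)  # each row sums to 1 over its allowed keys only
out = torch.matmul(A, V)           # (batch_size, seq_len, hid_dim)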

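Reformer uses locality-sensitive hashing (LSH) to sort tokens into buckets of similar vectors and computes attention only within each bucket. This reduces the cost from O(n^2) to O(n log n) in the sequence length n.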
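
The hashing step can be sketched with angular LSH based on a random rotation, h(x) = argmax([xR, -xR]), which is the scheme described for Reformer. This is a simplified illustration. The function name `lsh_buckets` is hypothetical, and the later sorting by bucket and chunked attention are omitted.

```python
import torch

def lsh_buckets(x, n_buckets, seed=0):
    """Angular LSH: h(x) = argmax([xR, -xR]) for a random matrix R.

    x: (seq_len, d). Returns one bucket id in [0, n_buckets) per token.
    Vectors pointing in similar directions tend to share a bucket.
    """
    assert n_buckets % 2 == 0, "n_buckets must be even"
    g = torch.Generator().manual_seed(seed)
    R = torch.randn(x.shape[-1], n_buckets // 2, generator=g)
    rotated = x @ R  # (seq_len, n_buckets // 2)
    return torch.cat([rotated, -rotated], dim=-1).argmax(dim=-1)

torch.manual_seed(0)
x = torch.randn(8, 64)  # 8 tokens, 64-dim vectors
buckets = lsh_buckets(x, n_buckets=4)
print(buckets)          # one bucket id for each of the 8 tokens
```

Only a vector's direction affects its bucket, so rescaling a vector leaves its bucket unchanged. Tokens whose vectors point the same way (and would produce large dot products) tend to be grouped together.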

Longformer combines sliding-window attention over a local context with global attention on a few task-specific tokens. Examples include the [CLS] token for classification or the question tokens in question answering.

These methods reduce the cost of attention from quadratic to linear (Longformer) or O(n log n) (Reformer) in sequence length, which makes far longer input sequences practical.
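
To see why a sliding window plus a few global tokens scales linearly, count the (query, key) pairs each pattern scores. This is an approximate count that ignores overlaps and sequence edges. The window of 256 tokens per side and the 8 global tokens are illustrative values.

```python
def dense_pairs(n):
    """(query, key) pairs scored by full attention: every token with every token."""
    return n * n

def window_global_pairs(n, window, n_global):
    """Approximate pairs for a sliding-window + global pattern: each token sees up to
    2 * window + 1 neighbors, and n_global tokens attend to (and are seen by)
    all n positions. Overlaps and sequence edges are ignored for simplicity."""
    return n * (2 * window + 1) + 2 * n_global * n

for n in (1024, 4096, 16384):
    ratio = dense_pairs(n) / window_global_pairs(n, window=256, n_global=8)
    print(f"{n:>6} tokens: full attention scores {ratio:.1f}x more pairs")
```

The windowed count is a constant times n, so the advantage over full attention grows in proportion to the sequence length. With these settings it goes from about 2x at 1,024 tokens to about 31x at 16,384.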

Mixture of Experts (MoE)

Another approach is to replace a layer's feed-forward block with several smaller expert networks and a learned router (gating network). The router decides how much each expert contributes for each token:

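import torch
import torch.nn as nn

class MoE(nn.Module):

    def __init__(self, hid_dim, n_experts):
        super().__init__()
        # One small network per expert
        self.experts = nn.ModuleList([nn.Linear(hid_dim, hid_dim)
                                      for _ in range(n_experts)])

        # Router: one score per expert for every token
        self.routing_fn = nn.Linear(hid_dim, n_experts)

    def forward(self, x):  # x: (batch_size, seq_len, hid_dim)

        # Run every expert: (batch_size, seq_len, hid_dim, n_experts)
        out = torch.stack([exp(x) for exp in self.experts], dim=-1)
        # Gate weights over experts: (batch_size, seq_len, 1, n_experts)
        gates = torch.softmax(self.routing_fn(x), dim=-1).unsqueeze(-2)

        # Gate-weighted sum of expert outputs: (batch_size, seq_len, hid_dim)
        out = torch.sum(out * gates, dim=-1)

        return out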

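This version still runs every expert on every token, so it adds capacity without saving compute. Large MoE models such as GShard and Switch Transformer make the approach efficient by routing each token to only its top-2 or top-1 experts. The cost per token then stays roughly constant as experts are added, and Switch Transformer used this to scale beyond a trillion parameters.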
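
The layer above runs every expert on every token. Sparse top-k routing, the idea behind efficient MoE layers, can be sketched as follows. `TopKMoE` is a hypothetical name. This simplified version has no load-balancing loss and no per-expert capacity limit, both of which production MoE layers rely on.

```python
import torch
import torch.nn as nn

class TopKMoE(nn.Module):
    """Sparse MoE layer: each token is processed by only its top-k experts.

    Simplified sketch: no load-balancing loss and no expert capacity limits.
    """

    def __init__(self, hid_dim, n_experts, k=2):
        super().__init__()
        self.k = k
        self.experts = nn.ModuleList(
            [nn.Linear(hid_dim, hid_dim) for _ in range(n_experts)]
        )
        self.router = nn.Linear(hid_dim, n_experts)

    def forward(self, x):  # x: (batch_size, seq_len, hid_dim)
        logits = self.router(x)                          # (batch, seq, n_experts)
        top_vals, top_idx = logits.topk(self.k, dim=-1)  # k best experts per token
        gates = torch.softmax(top_vals, dim=-1)          # renormalize over the k chosen
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            chosen = top_idx == e                        # (batch, seq, k)
            tokens = chosen.any(dim=-1)                  # tokens routed to expert e
            if tokens.any():
                # Gate weight this expert received for each routed token: (n_routed, 1)
                w = (gates * chosen).sum(dim=-1, keepdim=True)[tokens]
                out[tokens] += w * expert(x[tokens])     # expert runs on its tokens only
        return out

torch.manual_seed(0)
moe = TopKMoE(hid_dim=32, n_experts=8, k=2)
x = torch.randn(4, 10, 32)
print(moe(x).shape)  # torch.Size([4, 10, 32])
```

With k = 2 out of 8 experts, each token pays for two expert evaluations, while the layer holds the parameters of all eight.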

Scaling Laws

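As models grow, their performance follows predictable scaling laws. Kaplan et al. (2020) found that a language model's test loss falls as a power law in the number of parameters, the dataset size, and the training compute, over several orders of magnitude.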
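
The parameter-count law has the form L(N) = (N_c / N)^α_N. The sketch below uses the fitted constants reported by Kaplan et al. (2020), α_N ≈ 0.076 and N_c ≈ 8.8 × 10^13 non-embedding parameters. These fits are specific to their data and tokenization, and the law assumes that neither data nor compute is the bottleneck. Treat the absolute loss values as illustrative.

```python
def loss_from_params(n_params, n_c=8.8e13, alpha_n=0.076):
    """Power-law fit L(N) = (N_c / N) ** alpha_N for test loss (nats/token)
    versus non-embedding parameter count N."""
    return (n_c / n_params) ** alpha_n

for n in (1e8, 1e9, 1e10, 1e11):
    print(f"N = {n:.0e}: predicted loss = {loss_from_params(n):.2f}")

# Each 10x increase in parameters multiplies the loss by the same factor
print(f"ratio per 10x: {loss_from_params(1e10) / loss_from_params(1e9):.3f}")  # ratio per 10x: 0.839
```

A power law means steady but diminishing returns. Every tenfold increase in parameters removes roughly the same fraction (about 16%) of the remaining loss.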

Models like GPT-3 showed that, given sufficient data, Transformers keep improving as they scale to hundreds of billions of parameters, gaining more general capabilities along the way.

Understanding these scaling laws helps predict how much a larger model or dataset will improve performance, and therefore how much room is left for future progress.

The Path Forward

In just a few years, Transformers have become the premier architecture for NLP and beyond. Advances in efficient attention, model parallelism, sparse expert models, and scaling laws point toward AI systems with ever-greater reasoning and knowledge capacities. I'm excited to see these models continue to evolve and unlock new capabilities as they scale further in 2023 and beyond!

Support Kaan Berke UGURLAR by becoming a sponsor. Any amount is appreciated!