Attention Mechanism
Motivation
Issue with RNNs
- Information bottleneck: if the sentence is too long, the RNN has to compress all previous token information into a single hidden vector h
- Sequential nature: tokens are processed one at a time, which prevents parallelization across the sequence
Quotient rule
- Will be used while working out the derivative of softmax (a short derivation sketch follows after this list)
- \left(\frac{u}{v}\right)' = \frac{u'v - uv'}{v^2}.
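A quick derivation sketch (my own working, just applying the quotient rule above) of the softmax derivative used later:
- Let s_i = \text{Softmax}(z)_i = \frac{e^{z_i}}{\sum_k e^{z_k}}, i.e. u = e^{z_i} and v = \sum_k e^{z_k}
- For j = i: \frac{\partial s_i}{\partial z_i} = \frac{e^{z_i} v - e^{z_i} e^{z_i}}{v^2} = s_i(1 - s_i)
- For j \neq i: \frac{\partial s_i}{\partial z_j} = \frac{0 \cdot v - e^{z_i} e^{z_j}}{v^2} = -s_i s_j
- Together: \frac{\partial s_i}{\partial z_j} = s_i(\delta_{ij} - s_j), the Jacobian that shows up in the scaling discussion below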
Scaled Dot-Product Attention
Transformers
- Given a sequence of length N, where each token is represented by d_model numbers, we compute three matrices
- Query = XW_q (what I’m looking for)
- Key = XW_k (what I have)
- Value = XW_v (the information I’ll give if I’m relevant)
- W_q and W_k must have the same shape so that queries and keys can be dotted together; for symmetry W_v is usually given the same shape as well: W_q, W_k, W_v are all [d_{model}, d_k]
- To find out which keys are most relevant to a given query, we take the dot products via matrix multiplication and then apply softmax to turn the scores into probabilities
- score = QueryKey^T
- attn = softmax(\frac{score}{\sqrt{d_k}})
- Why scaling factor
- this dot-product score has variance d_k, which causes instability during backpropagation
- let’s say we have two vectors q and k of dimension d_k
- \mathbb{E}[q_i] = \mathbb{E}[k_i] = 0
- var[q_i]=var[k_i] = 1
- then \mathbb{E}[q \cdot k] = 0 but \text{var}[q \cdot k] = \sum_{i=1}^{d_k} \text{var}[q_i k_i] = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k (assuming the components are independent)
- the derivative of softmax contains terms that, without this normalization, become nearly 0 when the scores have such a large variance, because the softmax saturates
- \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}
- \frac{\partial \text{Softmax}(z)_i}{\partial z_j} = \text{Softmax}(z)_i (\delta_{ij} - \text{Softmax}(z)_j)
- If z_i is much larger than other elements, \text{Softmax}(z)_i \approx 1.
- If z_i is much smaller, \text{Softmax}(z)_i \approx 0.
- So for the Jacobian matrix, consider a simple case where a single query attends over all keys and one score dominates (s_1 \approx 1, all other s_j \approx 0)
- Diagonal: s_1(1-s_1) \approx 1(0) = 0.
- Off-Diagonal: -s_1 s_2 \approx -1(0) = 0.
- to enforce (approximately) unit variance, we normalize by dividing the score by its standard deviation, i.e. \sqrt{d_k}
The scaling factor keeps the variance of the dot-product scores controlled as d_k grows, which helps avoid softmax saturation and near-zero gradients.
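A minimal PyTorch sketch (my own, not part of the original notes; the numbers and variable names are illustrative) that empirically checks both claims: the unscaled dot product of two random d_k-dimensional vectors has variance around d_k, and a saturated softmax produces a near-zero gradient:
import torch
import torch.nn.functional as F
torch.manual_seed(0)
d_k = 512
q = torch.randn(10000, d_k)  # E[q_i] = 0, var[q_i] = 1
k = torch.randn(10000, d_k)
dots = (q * k).sum(dim=-1)               # 10000 independent dot products
print(dots.var().item())                 # roughly d_k (~512)
print((dots / d_k ** 0.5).var().item())  # roughly 1 after scaling by sqrt(d_k)
# Saturation: when one unscaled score dominates, every Jacobian entry
# s_i * (delta_ij - s_j) collapses toward 0, so the gradient vanishes.
scores = torch.tensor([30.0, 0.0, 0.0], requires_grad=True)
s = F.softmax(scores, dim=-1)            # ~[1, 0, 0]
s[0].backward()
print(scores.grad)                       # all entries ~0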
PyTorch Implementations
Single Head Vanilla Attention PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SingleHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.Wq = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_model, out_features=self.d_k)
    def forward(self, x):
        query = self.Wq(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        key = self.Wk(x)    # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        value = self.Wv(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        scores = torch.matmul(query, key.transpose(-1, -2)) / (self.d_k ** 0.5)  # [B, N, d_k] * [B, d_k, N] -> [B, N, N], TC = O(B*N^2*d_k)
        attn = F.softmax(scores, dim=-1)   # [B, N, N]
        final = torch.matmul(attn, value)  # [B, N, N] * [B, N, d_k] -> [B, N, d_k], TC = O(B*N^2*d_k)
        return final  # [B, N, d_k]
d_model = 512
d_k = 512
model = SingleHeadVanillaAttention(d_model=d_model, d_k=d_k).requires_grad_(False).eval().to(device)
n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
k = 50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"Time taken: {(time.time() - start) / k:.3f} s")
Multi Head Vanilla Attention For loop PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MultiHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_dim=None):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        if head_dim is None:
            self.d_k = d_model // self.num_heads
        else:
            self.d_k = head_dim
        self.Wq = nn.ModuleList([nn.Linear(in_features=self.d_model, out_features=self.d_k) for _ in range(self.num_heads)])
        self.Wk = nn.ModuleList([nn.Linear(in_features=self.d_model, out_features=self.d_k) for _ in range(self.num_heads)])
        self.Wv = nn.ModuleList([nn.Linear(in_features=self.d_model, out_features=self.d_k) for _ in range(self.num_heads)])
        self.out = nn.Linear(in_features=self.d_k * self.num_heads, out_features=self.d_model)
    def forward(self, x):
        query = [self.Wq[i](x) for i in range(self.num_heads)]  # each: [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        key = [self.Wk[i](x) for i in range(self.num_heads)]    # each: [B, N, d_k], TC = O(B*N*d_model*d_k)
        value = [self.Wv[i](x) for i in range(self.num_heads)]  # each: [B, N, d_k], TC = O(B*N*d_model*d_k)
        scores = [torch.matmul(query[i], key[i].transpose(-1, -2)) / (self.d_k ** 0.5) for i in range(self.num_heads)]  # each: [B, N, d_k] * [B, d_k, N] -> [B, N, N], TC = O(B*N^2*d_k)
        attn = [F.softmax(scores[i], dim=-1) for i in range(self.num_heads)]  # each: [B, N, N]
        final = [torch.matmul(attn[i], value[i]) for i in range(self.num_heads)]  # each: [B, N, N] * [B, N, d_k] -> [B, N, d_k], TC = O(B*N^2*d_k)
        final = torch.cat(final, dim=-1)  # [B, N, num_heads * d_k]
        projout = self.out(final)  # [B, N, num_heads * d_k] -> [B, N, d_model]
        return projout  # [B, N, d_model]
d_model = 512
model = MultiHeadVanillaAttention(d_model=d_model, num_heads=2, head_dim=32).requires_grad_(False).eval().to(device)
n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
k = 50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"Time taken: {(time.time() - start) / k:.3f} s")
Multi Head Vanilla Attention PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MultiHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, head_dim, num_heads):
        super().__init__()
        self.d_model = d_model
        self.d_k = head_dim * num_heads
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.Wq = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.out = nn.Linear(in_features=self.d_k, out_features=self.d_model)
    def forward(self, x):
        batch_size = x.shape[0]
        query = self.Wq(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        key = self.Wk(x)    # [B, N, d_k], TC = O(B*N*d_model*d_k)
        value = self.Wv(x)  # [B, N, d_k], TC = O(B*N*d_model*d_k)
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)      # [B, N, d_k] -> [B, num_heads, N, head_dim]
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, N, d_k] -> [B, num_heads, N, head_dim]
        scores = torch.matmul(query, key.transpose(-1, -2)) / (self.head_dim ** 0.5)  # [B, num_heads, N, head_dim] * [B, num_heads, head_dim, N] -> [B, num_heads, N, N], TC = O(B*N^2*d_k)
        attn = F.softmax(scores, dim=-1)  # [B, num_heads, N, N]
        final = torch.matmul(attn, value).permute(0, 2, 1, 3).reshape(batch_size, -1, self.d_k)  # [B, num_heads, N, N] * [B, num_heads, N, head_dim] -> [B, N, d_k], TC = O(B*N^2*d_k)
        projout = self.out(final)  # [B, N, d_k] -> [B, N, d_model]
        return projout  # [B, N, d_model]
d_model = 512
model = MultiHeadVanillaAttention(d_model=d_model, num_heads=2, head_dim=32).requires_grad_(False).eval().to(device)
n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
k = 50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"Time taken: {(time.time() - start) / k:.3f} s")
Linear Attention
Instead of committing to softmax, linear attention generalizes the formulation so that the softmax can be replaced by other similarity functions
A kernel is a function which takes two input vectors (query and key) and gives a single scalar output, i.e. a similarity score.
The softmax can therefore be replaced by some kernel, as long as it has the properties of a similarity function.
A kernel can be decomposed in terms of a feature function \phi
- K(a, b) = \phi(a)^T\phi(b)
The feature function takes a vector as input and projects it into a new feature space
The basic idea: instead of applying a complicated non-linear function like softmax, can't we just project a and b into a higher-dimensional space and take a plain (linear) inner product there?
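For intuition, here is a tiny sketch (my own example, not from the notes) using the degree-2 polynomial kernel K(a, b) = (a^T b)^2, whose feature map \phi(x) is the flattened outer product of x with itself:
import torch
torch.manual_seed(0)
a = torch.randn(4)
b = torch.randn(4)
phi = lambda x: torch.outer(x, x).reshape(-1)  # feature map: all pairwise products x_i * x_j
kernel_direct = (a @ b) ** 2                   # K(a, b) = (a^T b)^2
kernel_via_features = phi(a) @ phi(b)          # phi(a)^T phi(b)
print(torch.allclose(kernel_direct, kernel_via_features))  # True: same similarity score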
The attention mechanism is defined as A_l(x) = V' = \text{softmax} \left( \frac{QK^T}{\sqrt{D}} \right) V.
We can generalize the above equation by using a similarity function
V'_i = \frac{\sum_{j=1}^N \text{sim}(Q_i, K_j) V_j}{\sum_{j=1}^N \text{sim}(Q_i, K_j)}
where we can think of \text{sim}(q, k) = \exp \left( \frac{q^T k}{\sqrt{D}} \right)
The only constraint on sim is that it gives non-negative outputs
Given a kernel with feature representation \phi(x), we can rewrite the above equation
V'_i = \frac{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j) V_j}{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j)}
By using the associativity of matrix multiplication, we can pull \phi(Q_i) out of the summation, since it does not depend on j
V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^N \phi(K_j) V_j^T}{\phi(Q_i)^T \sum_{j=1}^N \phi(K_j)}
And finally we can convert it into vectorized format
\left( \phi(Q) \phi(K)^T \right) V = \phi(Q) \left( \phi(K)^T V \right)
It is linear (in N) because the sums \sum_j \phi(K_j) V_j^T and \sum_j \phi(K_j) in the numerator and denominator are computed once and then reused for every query (the equation before the vectorized one)
For softmax attention, the complexity is O(N^2 \cdot \max(D, M)), where D is the dimensionality of queries and keys and M is that of values
For linear attention, the complexity is O(NCM), where C is the feature-map projection dimension and M is the value dimensionality
Exact linearization of softmax attention is not feasible because the exponential kernel corresponds to an infinite-dimensional feature map, which cannot be stored on a computer (the series expansion of e^x has infinitely many terms)
That is why it is better to use a finite feature representation, such as a polynomial one
The authors chose elu (specifically \phi(x) = \text{elu}(x) + 1) instead of relu to avoid zero gradients for negative inputs
The key trick in linear attention is to avoid explicitly creating \phi(Q)\phi(K)^T first. By computing \phi(K)^T V before multiplying with \phi(Q), the attention can be computed more efficiently for long sequences.
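A small sketch (my own check, using the \text{elu}(x) + 1 feature map that the implementation below also uses) showing that both multiplication orders give the same result, while the right-hand side never materializes the N \times N matrix:
import torch
import torch.nn.functional as F
torch.manual_seed(0)
B, N, D = 1, 1024, 64
phi_q = F.elu(torch.randn(B, N, D, dtype=torch.float64)) + 1  # phi(Q), strictly positive
phi_k = F.elu(torch.randn(B, N, D, dtype=torch.float64)) + 1  # phi(K)
v = torch.randn(B, N, D, dtype=torch.float64)
left = torch.matmul(torch.matmul(phi_q, phi_k.transpose(-1, -2)), v)   # (phi(Q) phi(K)^T) V: builds an N x N matrix, O(N^2 * D)
right = torch.matmul(phi_q, torch.matmul(phi_k.transpose(-1, -2), v))  # phi(Q) (phi(K)^T V): only a D x D matrix, O(N * D^2)
print(torch.allclose(left, right))  # True, up to floating-point error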
Single Head Linear Attention PyTorch Code
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SingleHeadLinearAttention(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.Wq = nn.Linear(in_features=self.in_features, out_features=self.out_features)
        self.Wk = nn.Linear(in_features=self.in_features, out_features=self.out_features)
        self.Wv = nn.Linear(in_features=self.in_features, out_features=self.out_features)
    def forward(self, x):
        query = self.Wq(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k], TC = O(B*N*d_model*d_k)
        key = self.Wk(x)    # [B, N, d_k], TC = O(B*N*d_model*d_k)
        value = self.Wv(x)  # [B, N, d_k], TC = O(B*N*d_model*d_k)
        query = F.elu(query) + 1  # feature map phi(x) = elu(x) + 1, keeps outputs positive; [B, N, d_k]
        key = F.elu(key) + 1      # [B, N, d_k]
        scores = torch.matmul(key.transpose(-1, -2), value)  # [B, d_k, N] * [B, N, d_k] -> [B, d_k, d_k], TC = O(B*N*d_k^2)
        final = torch.matmul(query, scores)  # [B, N, d_k] * [B, d_k, d_k] -> [B, N, d_k], TC = O(B*N*d_k^2)
        k_sum = torch.sum(key, dim=-2).unsqueeze(-1)  # [B, d_k, 1]
        normalizer = torch.matmul(query, k_sum)  # [B, N, d_k] * [B, d_k, 1] -> [B, N, 1], TC = O(B*N*d_k)
        eps = 1e-6  # to avoid division by zero
        return final / (normalizer + eps)  # [B, N, d_k]
d_model = 512
d_k = 512
model = SingleHeadLinearAttention(in_features=d_model, out_features=d_k).requires_grad_(False).eval().to(device)
n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
k = 50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"Time taken: {(time.time() - start) / k:.3f} s")
| Method | Main idea | Complexity bottleneck |
|---|---|---|
| Single Head Vanilla Attention | One set of query, key, and value projections | Forms an N \times N attention matrix |
| Multi Head Vanilla Attention For loop | Computes each head separately and concatenates outputs | Still forms one N \times N matrix per head |
| Multi Head Vanilla Attention | Vectorizes the heads into one batched operation | Same quadratic attention cost, but better parallelism |
| Single Head Linear Attention | Uses a feature map and associativity to avoid explicit softmax attention | Depends on feature dimension instead of storing full N \times N attention |