Attention

Deep Learning
Author

Ritesh Kumar Maurya

Published

January 12, 2026

Issues with RNNs

  • Information Bottleneck: if the sentence is very long, the RNN has to compress all previous token information into a single vector h
  • Sequential Nature: tokens are processed one by one, so computation cannot be parallelized across the sequence

Quotient rule

  • Will be used while finding the derivative of softmax (a short worked derivation follows this list)
  • \left(\frac{u}{v}\right)' = \frac{u'v - uv'}{v^2}.
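
For concreteness, here is how the quotient rule yields the softmax derivative quoted in the Transformers section (this is the standard derivation, added here as a worked example). Write s_i = \text{Softmax}(z)_i = \frac{e^{z_i}}{\sum_k e^{z_k}} and take u = e^{z_i}, v = \sum_k e^{z_k}.

For j = i (so u' = e^{z_i}, v' = e^{z_i}):

\frac{\partial s_i}{\partial z_i} = \frac{e^{z_i} \sum_k e^{z_k} - e^{z_i} e^{z_i}}{\left( \sum_k e^{z_k} \right)^2} = s_i (1 - s_i)

For j \neq i (so u' = 0, v' = e^{z_j}):

\frac{\partial s_i}{\partial z_j} = \frac{0 - e^{z_i} e^{z_j}}{\left( \sum_k e^{z_k} \right)^2} = -s_i s_j

Both cases combine into \frac{\partial s_i}{\partial z_j} = s_i (\delta_{ij} - s_j).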

Transformers

  • Given a sequence of length N, where each token is represented by d_model numbers, we compute three matrices
    • Query = XWq (what I’m looking for)
    • Key = XWk (what I have)
    • Value = XWv (the information I’ll give if I’m relevant)
    • Wq and Wk must have the same shape so that queries and keys can be compared via dot products, and Wv is usually kept the same shape as well: Wq, Wk are [d_model, d_k] and Wv is [d_model, d_v], with d_v = d_k in the common case
  • Now, to find out which keys are most relevant to a given query, we take dot products between queries and keys and finally apply softmax to get a probability distribution
    • score = Query Key^T
    • attn = softmax(score/sqrt(d_k))
    • Why scaling factor
      • the dot-product scores have variance d_k, which causes instability during backpropagation
        • let's say we have two vectors q and k of dimension d_k
        • E[q_i] = E[k_i] = 0
        • var[q_i] = var[k_i] = 1
        • then E[q \cdot k] = 0 but var[q \cdot k] = \sum_{i=1}^{d_k} var[q_i k_i] = d_k, since the d_k terms are independent with unit variance (a quick numerical check appears after this list)
      • the derivative of softmax contains terms which, if the scores are not normalized, become nearly 0 because of this large variance
      • \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}
      • \frac{\partial \text{Softmax}(z)_i}{\partial z_j} = \text{Softmax}(z)_i (\delta_{ij} - \text{Softmax}(z)_j)
      • If z_i is much larger than other elements, \text{Softmax}(z)_i \approx 1.
      • If z_i is much smaller, \text{Softmax}(z)_i \approx 0.
      • So, looking at the Jacobian in the saturated case where one probability dominates (s_1 \approx 1 and all other s_j \approx 0):
        • Diagonal: s_1(1-s_1) \approx 1(0) = 0.
        • Off-Diagonal: -s_1 s_2 \approx -1(0) = 0.
      • to enforce unit variance, we normalize by dividing the scores by the standard deviation, i.e. sqrt(d_k) (see the Jacobian check after this list)
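
A minimal numerical check of the variance claim (a sketch, assuming the components of q and k are i.i.d. standard normal):

import torch

torch.manual_seed(0)
d_k, n_samples = 512, 10_000

q = torch.randn(n_samples, d_k)  # components with mean 0, variance 1
k = torch.randn(n_samples, d_k)

dots = (q * k).sum(dim=-1)                # one dot product q.k per row
print(dots.var().item())                  # ~ d_k (around 512)
print((dots / d_k ** 0.5).var().item())   # ~ 1 after dividing by sqrt(d_k)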
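
And a small check of the saturation argument (a sketch with a hand-picked score vector, not taken from any real model): multiplying unit-scale scores by sqrt(d_k) mimics the unscaled variance, and the softmax Jacobian entries collapse towards 0.

import torch
import torch.nn.functional as F

def softmax_jacobian(z):
    s = F.softmax(z, dim=-1)
    return torch.diag(s) - torch.outer(s, s)  # J_ij = s_i (delta_ij - s_j)

z = torch.tensor([1.2, 0.3, -0.5, 0.8, -1.1, 0.0, 0.6, -0.2])  # unit-scale scores
d_k = 512

print(softmax_jacobian(z * d_k ** 0.5).abs().max().item())  # ~1e-4: saturated softmax, near-zero gradients
print(softmax_jacobian(z).abs().max().item())               # ~0.21: healthy gradient signal after scaling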

Single Head Vanilla Attention PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SingleHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.Wq = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wk = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wv = nn.Linear(in_features = self.d_model, out_features = self.d_k)

    def forward(self, x):
        query = self.Wq(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = self.Wk(x)   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = self.Wv(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        scores = torch.matmul(query, key.transpose(-1, -2))/(self.d_k**0.5) # [B, N, d_k] * [B, d_k, N] -> [B, N, N] TC=O(B*N^2*d_k)

        attn = F.softmax(scores, dim=-1) # [B, N, N]

        final = torch.matmul(attn, value) # [B, N, N] * [B, N, d_k] -> [B, N, d_k]  TC=O(B*N^2*d_k)
        return final # [B, N, d_k]

d_model = 512
d_k = 512

model = SingleHeadVanillaAttention(d_model = d_model, d_k = d_k).requires_grad_(False).eval().to(device)

n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)

torch.cuda.synchronize()
start = time.time()
k=50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
torch.cuda.synchronize()
print(f"Time taken: {(time.time()-start)/k:.3f}")
Time taken: 0.457

Multi Head Vanilla Attention (For Loop) PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultiHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_dim=None):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        if head_dim is None:
            self.d_k = d_model//self.num_heads
        else:
            self.d_k = head_dim
        self.Wq = nn.ModuleList([nn.Linear(in_features = self.d_model, out_features = self.d_k) for _ in range(self.num_heads)])
        self.Wk = nn.ModuleList([nn.Linear(in_features = self.d_model, out_features = self.d_k) for _ in range(self.num_heads)])
        self.Wv = nn.ModuleList([nn.Linear(in_features = self.d_model, out_features = self.d_k) for _ in range(self.num_heads)])
        self.out = nn.Linear(in_features = self.d_k * self.num_heads, out_features = self.d_model)

    def forward(self, x):
        query = [self.Wq[i](x) for i in range(self.num_heads)] # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = [self.Wk[i](x) for i in range(self.num_heads)]   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = [self.Wv[i](x) for i in range(self.num_heads)] # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        scores = [torch.matmul(query[i], key[i].transpose(-1, -2))/(self.d_k**0.5) for i in range(self.num_heads)] # [B, N, d_k] * [B, d_k, N] -> [B, N, N] TC=O(B*N^2*d_k)

        attn = [F.softmax(scores[i], dim=-1) for i in range(self.num_heads)] # [B, N, N]

        final = [torch.matmul(attn[i], value[i]) for i in range(self.num_heads)] # [B, N, N] * [B, N, d_k] -> [B, N, d_k]  TC=O(B*N^2*d_k)
        final = torch.cat(final, dim=-1)

        projout = self.out(final)

        return projout # [B, N, d_model]

d_model = 512
d_k = 512

model = MultiHeadVanillaAttention(d_model = d_model, num_heads=2, head_dim=32).requires_grad_(False).eval().to(device)

n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)

torch.cuda.synchronize()
start = time.time()
k=50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
torch.cuda.synchronize()
print(f"Time taken: {(time.time()-start)/k:.3f}")
Time taken: 0.769

Multi Head Vanilla Attention (Vectorized) PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultiHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, head_dim, num_heads):
        super().__init__()
        self.d_model = d_model
        self.d_k = head_dim * num_heads
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.Wq = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wk = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wv = nn.Linear(in_features = self.d_model, out_features = self.d_k)

        self.out = nn.Linear(in_features = self.d_k, out_features = self.d_model)
    def forward(self, x):
        batch_size = x.shape[0]
        query = self.Wq(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = self.Wk(x)   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = self.Wv(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]


        scores = torch.matmul(query, key.transpose(-1, -2))/(self.head_dim**0.5) # [B, num_heads, N, head_dim] * [B, num_heads, head_dim, N] -> [B, num_heads, N, N] TC=O(B*N^2*d_k)

        attn = F.softmax(scores, dim=-1) # [B, num_heads, N, N]

        final = torch.matmul(attn, value).permute(0, 2, 1, 3).reshape(batch_size, -1, self.d_k) # [B, num_heads, N, N] * [B, num_heads, N, head_dim] -> [B, N, d_k]  TC=O(B*N^2*d_k)
        projout = self.out(final) # [B, N, d_k] -> [B, N, d_model]
        return projout # [B, N, d_model]


d_model = 512
d_k = 512

model = MultiHeadVanillaAttention(d_model = d_model, num_heads=2, head_dim=32).requires_grad_(False).eval().to(device)

n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)

torch.cuda.synchronize()
start = time.time()
k=50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
torch.cuda.synchronize()
print(f"Time taken: {(time.time()-start)/k:.3f}")
Time taken: 2.617

Linear Attention

  • Instead of committing to softmax, linear attention generalises the formulation so that softmax can be replaced by other functions

  • A kernel is a function which takes two input vectors (query and key) and gives a single scalar output, i.e. a similarity score.

  • The softmax can therefore be replaced by a kernel, as long as that kernel behaves like a similarity function.

  • A kernel can be decomposed via a feature function into a plain inner product

    • K(a, b) = \phi(a)^T \phi(b)
  • A feature function takes a vector as input and projects it into a new feature space

  • The basic idea: instead of applying a complicated non-linear function like softmax, can't we just project a and b into a higher-dimensional space and take a linear inner product there?

  • The attention mechanism is defined as A_l(x) = V' = \text{softmax} \left( \frac{QK^T}{\sqrt{D}} \right) V.

  • We can generalise the above equation by using a similarity function

  • V'_i = \frac{\sum_{j=1}^N \text{sim}(Q_i, K_j) V_j}{\sum_{j=1}^N \text{sim}(Q_i, K_j)}

  • where we can think of \text{sim}(q, k) = \exp \left( \frac{q^T k}{\sqrt{D}} \right)

  • The only constraint on sim is that it gives non-negative outputs

  • Given a kernel with feature representation \phi(x), we can rewrite the above equation

  • V'_i = \frac{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j) V_j}{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j)}

  • Using the associative property of matrix multiplication, we can pull \phi(Q_i) out of the summation, because it is independent of j

  • V'_i = \frac{\phi(Q_i)^T \sum_{j=1}^N \phi(K_j) V_j^T}{\phi(Q_i)^T \sum_{j=1}^N \phi(K_j)}

  • And finally we can convert it into vectorized format

  • \left( \phi(Q) \phi(K)^T \right) V = \phi(Q) \left( \phi(K)^T V \right)

  • It is linear in N because the numerator and denominator sums (in the equation before the vectorized one) can be computed once and reused for all the queries (see the sanity check after this list)

  • For softmax attention, the complexity is O(N^2 * max(D, M)), where D is the dimensionality of queries and keys and M is the dimensionality of values

  • For linear attention, the complexity is O(N * C * M), where C is the feature-map projection dimension and M is the dimensionality of values

  • Linearizing exact softmax attention is not feasible because of the exponential kernel, whose feature map is infinite-dimensional and can't be stored on a computer (the expansion of e^x has infinitely many terms)

  • That's why it is better to use some kind of polynomial feature representation

  • The authors chose ELU instead of ReLU (specifically \phi(x) = \text{elu}(x) + 1, as in the code below) to avoid setting gradients to 0 when the input is negative
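
A quick sanity check of the associativity identity above (a minimal sketch using \phi(x) = \text{elu}(x) + 1, matching the code below): both multiplication orders give the same result, but the right-hand side never materializes the N x N matrix.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d_k = 128, 64
phi_Q = F.elu(torch.randn(N, d_k)) + 1  # phi(Q), non-negative features
phi_K = F.elu(torch.randn(N, d_k)) + 1  # phi(K)
V = torch.randn(N, d_k)

quadratic = (phi_Q @ phi_K.T) @ V   # O(N^2 * d_k): builds the N x N attention matrix
linear = phi_Q @ (phi_K.T @ V)      # O(N * d_k^2): only a d_k x d_k summary matrix

print((quadratic - linear).abs().max().item())  # tiny, just floating-point rounding

The SingleHeadLinearAttention module below implements this second ordering, plus the normalizer from the per-query equation.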

Single Head Linear Attention PyTorch Code

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SingleHeadLinearAttention(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.Wq = nn.Linear(in_features = self.in_features, out_features = self.out_features)
        self.Wk = nn.Linear(in_features = self.in_features, out_features = self.out_features)
        self.Wv = nn.Linear(in_features = self.in_features, out_features = self.out_features)

    def forward(self, x):
        query = self.Wq(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = self.Wk(x)   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = self.Wv(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        query = F.elu(query) + 1 # [B, N, d_k]
        key = F.elu(key) + 1 # [B, N, d_k]

        scores = torch.matmul(key.transpose(-1, -2), value) # [B, d_k, N] * [B, N, d_k] -> [B, d_k, d_k] TC=O(B*N*d_k^2)

        final = torch.matmul(query, scores) # [B, N, d_k] * [B, d_k, d_k] -> [B, N, d_k] TC=O(B*N*d_k^2)
        k_sum = torch.sum(key, dim=-2).unsqueeze(-1) # [B, N, d_k] -> [B, d_k] -> [B, d_k, 1]

        normalizer = torch.matmul(query, k_sum) # [B, N, d_k] * [B, d_k, 1] -> [B, N, 1] TC=O(B*N*d_k)
        eps = 1e-6 # Small constant to avoid division by zero

        return final/(normalizer + eps) # [B, N, d_k]

d_model = 512
d_k = 512

model = SingleHeadLinearAttention(in_features = d_model, out_features = d_k).requires_grad_(False).eval().to(device)

n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)

torch.cuda.synchronize()
start = time.time()
k=50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
torch.cuda.synchronize()
print(f"Time taken: {(time.time()-start)/k:.3f}")
Time taken: 0.038