import torch
import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SingleHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.Wq = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_model, out_features=self.d_k)
    def forward(self, x):
        query = self.Wq(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        key = self.Wk(x)    # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        value = self.Wv(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        scores = torch.matmul(query, key.transpose(-1, -2)) / (self.d_k ** 0.5)  # [B, N, d_k] * [B, d_k, N] -> [B, N, N]  TC = O(B*N^2*d_k)
        attn = F.softmax(scores, dim=-1)   # [B, N, N]
        final = torch.matmul(attn, value)  # [B, N, N] * [B, N, d_k] -> [B, N, d_k]  TC = O(B*N^2*d_k)
        return final  # [B, N, d_k]
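
# A quick shape check of the single-head module (sketch; the batch size, sequence length,
# d_model and d_k below are arbitrary values chosen only for illustration).
sha = SingleHeadVanillaAttention(d_model=512, d_k=64)
x = torch.randn(2, 10, 512)   # [B=2, N=10, d_model=512]
print(sha(x).shape)           # expected: torch.Size([2, 10, 64])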


class MultiHeadVanillaAttention(nn.Module):
    # Naive multi-head attention: one Q/K/V projection per head, applied in a Python loop.
    def __init__(self, d_model, num_heads, head_dim=None):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        if head_dim is None:
            self.d_k = d_model // self.num_heads
        else:
            self.d_k = head_dim
        self.Wq = nn.ModuleList([nn.Linear(in_features=self.d_model, out_features=self.d_k) for _ in range(self.num_heads)])
        self.Wk = nn.ModuleList([nn.Linear(in_features=self.d_model, out_features=self.d_k) for _ in range(self.num_heads)])
        self.Wv = nn.ModuleList([nn.Linear(in_features=self.d_model, out_features=self.d_k) for _ in range(self.num_heads)])
        self.out = nn.Linear(in_features=self.d_k * self.num_heads, out_features=self.d_model)
    def forward(self, x):
        query = [self.Wq[i](x) for i in range(self.num_heads)]  # each: [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        key = [self.Wk[i](x) for i in range(self.num_heads)]    # each: [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        value = [self.Wv[i](x) for i in range(self.num_heads)]  # each: [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        scores = [torch.matmul(query[i], key[i].transpose(-1, -2)) / (self.d_k ** 0.5) for i in range(self.num_heads)]  # each: [B, N, d_k] * [B, d_k, N] -> [B, N, N]  TC = O(B*N^2*d_k)
        attn = [F.softmax(scores[i], dim=-1) for i in range(self.num_heads)]          # each: [B, N, N]
        final = [torch.matmul(attn[i], value[i]) for i in range(self.num_heads)]      # each: [B, N, N] * [B, N, d_k] -> [B, N, d_k]  TC = O(B*N^2*d_k)
        final = torch.cat(final, dim=-1)  # [B, N, num_heads * d_k]
        projout = self.out(final)         # [B, N, num_heads * d_k] -> [B, N, d_model]
        return projout  # [B, N, d_model]
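
# Illustrative usage of the loop-based multi-head variant (sketch; sizes are arbitrary).
# With head_dim=None each head gets d_k = d_model // num_heads, and the concatenated
# heads are projected back to d_model.
mha_loop = MultiHeadVanillaAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
print(mha_loop(x).shape)      # expected: torch.Size([2, 10, 512])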


class MultiHeadVanillaAttention(nn.Module):
    # Vectorised multi-head attention: all heads share fused Q/K/V projections and run in a
    # single batched matmul (this definition replaces the loop-based one above).
    def __init__(self, d_model, head_dim, num_heads):
        super().__init__()
        self.d_model = d_model
        self.d_k = head_dim * num_heads
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.Wq = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.out = nn.Linear(in_features=self.d_k, out_features=self.d_model)
    def forward(self, x):
        batch_size = x.shape[0]
        query = self.Wq(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        key = self.Wk(x)    # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        value = self.Wv(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)      # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        scores = torch.matmul(query, key.transpose(-1, -2)) / (self.head_dim ** 0.5)  # [B, num_heads, N, head_dim] * [B, num_heads, head_dim, N] -> [B, num_heads, N, N]  TC = O(B*N^2*d_k)
        attn = F.softmax(scores, dim=-1)  # [B, num_heads, N, N]
        final = torch.matmul(attn, value).permute(0, 2, 1, 3).reshape(batch_size, -1, self.d_k)  # [B, num_heads, N, N] * [B, num_heads, N, head_dim] -> [B, N, d_k]  TC = O(B*N^2*d_k)
        projout = self.out(final)  # [B, N, d_k] -> [B, N, d_model]
        return projout  # [B, N, d_model]
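
# Optional cross-check of the vectorised module against F.scaled_dot_product_attention
# (a sketch that assumes PyTorch >= 2.0, where that function is available). Both paths
# compute softmax(Q K^T / sqrt(head_dim)) V per head, so the results should match.
mha = MultiHeadVanillaAttention(d_model=512, head_dim=32, num_heads=2)
x = torch.randn(2, 10, 512)
with torch.no_grad():
    ours = mha(x)
    q = mha.Wq(x).view(2, -1, mha.num_heads, mha.head_dim).transpose(1, 2)
    k = mha.Wk(x).view(2, -1, mha.num_heads, mha.head_dim).transpose(1, 2)
    v = mha.Wv(x).view(2, -1, mha.num_heads, mha.head_dim).transpose(1, 2)
    ref = F.scaled_dot_product_attention(q, k, v)               # [B, num_heads, N, head_dim]
    ref = mha.out(ref.transpose(1, 2).reshape(2, -1, mha.d_k))  # back to [B, N, d_model]
print(torch.allclose(ours, ref, atol=1e-5))                     # expected: True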


d_model = 512
d_k = 512

model = MultiHeadVanillaAttention(d_model=d_model, num_heads=2, head_dim=32).requires_grad_(False).eval().to(device)

n_tokens = 100
inp = torch.randn(1, n_tokens, d_model, device=device)

if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure any pending GPU work is done before timing
start = time.time()
k = 5
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for the GPU to finish before reading the clock
print(f"Time taken: {(time.time()-start)/k:.3f}")out.shapeimport torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SingleHeadLinearAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.Wq = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wk = nn.Linear(in_features=self.d_model, out_features=self.d_k)
        self.Wv = nn.Linear(in_features=self.d_model, out_features=self.d_k)
    def forward(self, x):
        query = self.Wq(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        key = self.Wk(x)    # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        value = self.Wv(x)  # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k]  TC = O(B*N*d_model*d_k)
        query = F.elu(query) + 1  # feature map phi(.) = elu(.) + 1 keeps activations positive; [B, N, d_k]
        key = F.elu(key) + 1      # [B, N, d_k]
        scores = torch.matmul(key.transpose(-1, -2), value)  # [B, d_k, N] * [B, N, d_k] -> [B, d_k, d_k]  TC = O(B*N*d_k^2)
        final = torch.matmul(query, scores)                  # [B, N, d_k] * [B, d_k, d_k] -> [B, N, d_k]  TC = O(B*N*d_k^2)
        k_sum = torch.sum(key, dim=-2).unsqueeze(-1)         # [B, d_k, 1], sum of phi(K) over the sequence
        normalizer = torch.matmul(query, k_sum)              # [B, N, d_k] * [B, d_k, 1] -> [B, N, 1]  TC = O(B*N*d_k)
        eps = 1e-6  # guards against division by zero
        return final / (normalizer + eps)  # [B, N, d_k]
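
# Sanity check (sketch, small arbitrary sizes): the linear-attention trick only reassociates
# the matmuls, so phi(Q) @ (phi(K)^T V) with the running normaliser must match the explicit
# N x N formulation (phi(Q) phi(K)^T) V with row-wise normalisation.
lin = SingleHeadLinearAttention(d_model=64, d_k=32)
x = torch.randn(2, 10, 64)
with torch.no_grad():
    fast = lin(x)
    phi_q = F.elu(lin.Wq(x)) + 1
    phi_k = F.elu(lin.Wk(x)) + 1
    v = lin.Wv(x)
    a = torch.matmul(phi_q, phi_k.transpose(-1, -2))                  # [B, N, N] explicit attention matrix
    slow = torch.matmul(a, v) / (a.sum(dim=-1, keepdim=True) + 1e-6)  # row-normalised O(N^2) version
print(torch.allclose(fast, slow, atol=1e-4))                          # expected: True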


d_model = 512
d_k = 512

model = SingleHeadLinearAttention(d_model=d_model, d_k=d_k).requires_grad_(False).eval().to(device)

n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)

if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure any pending GPU work is done before timing
start = time.time()
k = 50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for the GPU to finish before reading the clock
print(f"Time taken: {(time.time()-start)/k:.3f}")Time taken: 0.039