import torch
import torch.nn as nn
import torch.nn.functional as F
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SingleHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.Wq = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wk = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wv = nn.Linear(in_features = self.d_model, out_features = self.d_k)

    def forward(self, x):
        query = self.Wq(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = self.Wk(x)   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = self.Wv(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        scores = torch.matmul(query, key.transpose(-1, -2))/(self.d_k**0.5) # [B, N, d_k] * [B, d_k, N] -> [B, N, N] TC=O(B*N^2*d_k)

        attn = F.softmax(scores, dim=-1) # [B, N, N]

        final = torch.matmul(attn, value) # [B, N, N] * [B, N, d_k] -> [B, N, d_k]  TC=O(B*N^2*d_k)
        return final # [B, N, d_k]
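# A minimal shape check for the single-head block above (illustrative sizes,
# not part of the original benchmark): the output keeps the sequence length
# and maps d_model -> d_k.
_sha = SingleHeadVanillaAttention(d_model=64, d_k=16)
_x = torch.randn(2, 10, 64)             # [B=2, N=10, d_model=64]
assert _sha(_x).shape == (2, 10, 16)    # [B, N, d_k]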
class MultiHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, num_heads, head_dim=None):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        if head_dim is None:
            self.d_k = d_model//self.num_heads
        else:
            self.d_k = head_dim
        self.Wq = nn.ModuleList([nn.Linear(in_features = self.d_model, out_features = self.d_k) for _ in range(self.num_heads)])
        self.Wk = nn.ModuleList([nn.Linear(in_features = self.d_model, out_features = self.d_k) for _ in range(self.num_heads)])
        self.Wv = nn.ModuleList([nn.Linear(in_features = self.d_model, out_features = self.d_k) for _ in range(self.num_heads)])
        self.out = nn.Linear(in_features = self.d_k * self.num_heads, out_features = self.d_model)

    def forward(self, x):
        query = [self.Wq[i](x) for i in range(self.num_heads)] # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = [self.Wk[i](x) for i in range(self.num_heads)]   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = [self.Wv[i](x) for i in range(self.num_heads)] # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        scores = [torch.matmul(query[i], key[i].transpose(-1, -2))/(self.d_k**0.5) for i in range(self.num_heads)] # [B, N, d_k] * [B, d_k, N] -> [B, N, N] TC=O(B*N^2*d_k)

        attn = [F.softmax(scores[i], dim=-1) for i in range(self.num_heads)] # [B, N, N]

        final = [torch.matmul(attn[i], value[i]) for i in range(self.num_heads)] # [B, N, N] * [B, N, d_k] -> [B, N, d_k]  TC=O(B*N^2*d_k)
        final = torch.cat(final, dim=-1)

        projout = self.out(final)

        return projout # [B, N, d_model]
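# A minimal shape check for the per-head-loop version (illustrative sizes):
# each of the num_heads projections yields [B, N, d_k], the concat gives
# [B, N, num_heads * d_k], and the output projection maps back to d_model.
_mha_loop = MultiHeadVanillaAttention(d_model=64, num_heads=4)   # d_k defaults to 64 // 4 = 16
_x = torch.randn(2, 10, 64)
assert _mha_loop(_x).shape == (2, 10, 64)                        # [B, N, d_model]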
# Vectorized multi-head attention: same computation as the loop version above,
# but all heads share one fused projection and are split with view/transpose.
# Note that this redefines (shadows) the MultiHeadVanillaAttention class above.
class MultiHeadVanillaAttention(nn.Module):
    def __init__(self, d_model, head_dim, num_heads):
        super().__init__()
        self.d_model = d_model
        self.d_k = head_dim * num_heads # total projection width across all heads
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.Wq = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wk = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wv = nn.Linear(in_features = self.d_model, out_features = self.d_k)

        self.out = nn.Linear(in_features = self.d_k, out_features = self.d_model)
    def forward(self, x):
        batch_size = x.shape[0]
        query = self.Wq(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = self.Wk(x)   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = self.Wv(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # [B, N, d_k] -> [B, N, num_heads, head_dim] -> [B, num_heads, N, head_dim]


        scores = torch.matmul(query, key.transpose(-1, -2))/(self.head_dim**0.5) # [B, num_heads, N, head_dim] * [B, num_heads, head_dim, N] -> [B, num_heads, N, N] TC=O(B*N^2*d_k)

        attn = F.softmax(scores, dim=-1) # [B, num_heads, N, N]

        final = torch.matmul(attn, value).permute(0, 2, 1, 3).reshape(batch_size, -1, self.d_k) # [B, num_heads, N, N] * [B, num_heads, N, head_dim] -> [B, N, d_k]  TC=O(B*N^2*d_k)
        projout = self.out(final) # [B, N, d_k] -> [B, N, d_model]
        return projout # [B, N, d_model]
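# A small illustration (assumed sizes) of the head split used above: a fused
# [B, N, num_heads * head_dim] projection is reshaped to
# [B, N, num_heads, head_dim] and transposed to [B, num_heads, N, head_dim],
# so every head attends independently inside one batched matmul.
_fused = torch.randn(2, 10, 4 * 16)                    # [B=2, N=10, num_heads*head_dim=64]
_split = _fused.view(2, 10, 4, 16).transpose(1, 2)     # [B, num_heads, N, head_dim]
assert _split.shape == (2, 4, 10, 16)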
d_model = 512
d_k = 512
model = MultiHeadVanillaAttention(d_model = d_model, num_heads=2, head_dim=32).requires_grad_(False).eval().to(device)
n_tokens = 100
inp = torch.randn(1, n_tokens, d_model, device=device)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
k=5
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"Time taken: {(time.time()-start)/k:.3f} s")
print(out.shape) # torch.Size([1, 100, 512])
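# Optional (a sketch, CUDA only): CUDA events plus a warm-up pass give a more
# reliable GPU timing than time.time(), since kernel launches are asynchronous.
if device.type == "cuda":
    _ = model(inp)                                   # warm-up
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(k):
        with torch.no_grad():
            out = model(inp)
    end_evt.record()
    torch.cuda.synchronize()
    print(f"CUDA event time: {start_evt.elapsed_time(end_evt) / k:.3f} ms")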
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SingleHeadLinearAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.Wq = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wk = nn.Linear(in_features = self.d_model, out_features = self.d_k)
        self.Wv = nn.Linear(in_features = self.d_model, out_features = self.d_k)

    def forward(self, x):
        query = self.Wq(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        key = self.Wk(x)   #[B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)
        value = self.Wv(x) # [B, N, d_model] * [d_model, d_k] -> [B, N, d_k] TC=O(B*N*d_model*d_k)

        query = F.elu(query) + 1 # [B, N, d_k]
        key = F.elu(key) + 1 # [B, N, d_k]

        scores = torch.matmul(key.transpose(-1, -2), value) # [B, d_k, N] * [B, N, d_k] -> [B, d_k, d_k] TC=O(B*N*d_k^2)

        final = torch.matmul(query, scores) # [B, N, d_k] * [B, d_k, d_k] -> [B, N, d_k] TC=O(B*N*d_k^2)
        k_sum = torch.sum(key, dim=-2).unsqueeze(-1) # [B, N, d_k] -> [B, d_k] -> [B, d_k, 1]

        normalizer = torch.matmul(query, k_sum) # [B, N , d_k] * [B, d_k, 1] -> [B, N, 1] TC=O(B*N*d_k)
        eps = 1e-6 # small constant so we never divide by zero

        return final/(normalizer + eps) # [B, N, d_k]
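# A minimal shape check for linear attention (illustrative sizes): the
# key^T @ value product is [B, d_k, d_k] regardless of sequence length,
# which is what removes the O(N^2) term.
_lin = SingleHeadLinearAttention(d_model=64, d_k=16)
_x = torch.randn(2, 10, 64)
assert _lin(_x).shape == (2, 10, 16)   # [B, N, d_k]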
d_model = 512
d_k = 512
model = SingleHeadLinearAttention(d_model = d_model, d_k = d_k).requires_grad_(False).eval().to(device)
n_tokens = 16384
inp = torch.randn(1, n_tokens, d_model, device=device)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
k=50
for _ in range(k):
    with torch.no_grad():
        out = model(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"Time taken: {(time.time()-start)/k:.3f} s")
Time taken: 0.039 s
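# For comparison (a hedged sketch, not part of the original run): the quadratic
# SingleHeadVanillaAttention on the same 16384-token input materialises a
# [1, 16384, 16384] attention matrix, so it is expected to be much slower and
# may OOM on small GPUs; exact numbers depend on the hardware.
vanilla = SingleHeadVanillaAttention(d_model=d_model, d_k=d_k).requires_grad_(False).eval().to(device)
k_vanilla = 5 # fewer repeats, since each pass is O(N^2 * d_k)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
for _ in range(k_vanilla):
    with torch.no_grad():
        out = vanilla(inp)
if device.type == "cuda":
    torch.cuda.synchronize()
print(f"Vanilla attention time: {(time.time()-start)/k_vanilla:.3f} s")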