Quantization in Depth

Optimization
Author

Ritesh Kumar Maurya

Published

May 24, 2024

Quantize and De-quantize a tensor

  • Advantages of Quantization
    • Smaller model
    • Speed gains
      • Memory bandwidth
      • Faster operations
        • GEMM: General Matrix Multiply (matrix to matrix multiplication)
        • GEMV: General Matrix Vector Multiplication (matrix to vector multiplication)
  • Challenges of Quantization
    • Quantization error
    • Retraining (Quantization Aware Training)
    • Limited Hardware support
    • Calibration dataset needed
    • packing/unpacking
  • Getting q (r is the original value, q the quantized value, s the scale, z the zero point):-
    • r = s(q - z)
    • q = int(round(r/s + z))
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

#Helper functions to visualize
def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
    """
    Plot a heatmap of tensors using seaborn
    """
    sns.heatmap(tensor.cpu().numpy(), ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, annot=True, fmt=".2f", cbar=False)
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])


def plot_quantization_errors(original_tensor, quantized_tensor, dequantized_tensor, dtype = torch.int8, n_bits = 8):
    """
    A method that plots 4 matrices, the original tensor, the quantized tensor
    the de-quantized tensor and the error tensor.
    """
    # Get a figure of 4 plots
    fig, axes = plt.subplots(1, 4, figsize=(15, 4))

    # Plot the first matrix
    plot_matrix(original_tensor, axes[0], 'Original Tensor', cmap=ListedColormap(['white']))

    # Get the quantization range and plot the quantized tensor
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    plot_matrix(quantized_tensor, axes[1], f'{n_bits}-bit Linear Quantized Tensor', vmin=q_min, vmax=q_max, cmap='coolwarm')

    # Plot the de-quantized tensors
    plot_matrix(dequantized_tensor, axes[2], 'Dequantized Tensor', cmap='coolwarm')

    # Get the quantization errors
    q_error_tensor = abs(original_tensor - dequantized_tensor)
    plot_matrix(q_error_tensor, axes[3], 'Quantization Error Tensor', cmap=ListedColormap(['white']))

    fig.tight_layout()
    plt.show()

def linear_q_with_scale_and_zero_point(
    tensor, scale, zero_point, dtype = torch.int8):

    scaled_and_shifted_tensor = tensor / scale + zero_point

    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max

    q_tensor = rounded_tensor.clamp(q_min,q_max).to(dtype)
    
    return q_tensor

test_tensor=torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5,  -184],
     [0,     684.6, 245.5]])

scale = 3.5
zero_point = -70

quantized_tensor = linear_q_with_scale_and_zero_point(
    test_tensor, scale, zero_point)

def linear_dequantization(quantized_tensor, scale, zero_point):
    return scale * (quantized_tensor.float() - zero_point)

dequantized_tensor = linear_dequantization(
    quantized_tensor, scale, zero_point)

plot_quantization_errors(test_tensor, quantized_tensor,
                         dequantized_tensor)
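To see what the two formulas do to individual numbers, here is a minimal sketch in plain Python (no PyTorch). It reuses three values from the test tensor with the same `scale = 3.5` and `zero_point = -70`; the helper names `quantize_value` and `dequantize_value` are made up for this illustration.

```python
def quantize_value(r, scale, zero_point, q_min=-128, q_max=127):
    """Quantize one float: q = clamp(round(r / s + z), q_min, q_max)."""
    q = round(r / scale + zero_point)
    return max(q_min, min(q_max, q))


def dequantize_value(q, scale, zero_point):
    """Recover an approximation of r: r ~ s * (q - z)."""
    return scale * (q - zero_point)


scale, zero_point = 3.5, -70

for r in (191.6, -13.5, 728.6):
    q = quantize_value(r, scale, zero_point)
    r_hat = dequantize_value(q, scale, zero_point)
    print(f"r={r:>7}  q={q:>4}  dequantized={r_hat:>6}  error={abs(r - r_hat):.2f}")
```

With these settings 191.6 maps to -15 and comes back as 192.5, a small rounding error. 728.6 would need the code 138, which does not fit in int8, so it saturates at 127 and comes back as 689.5. This is why a hand-picked scale and zero point can give large errors, and why the next section derives them from the tensor's range.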

Get the Scale and Zero-Point

  • s = (r_max - r_min) / (q_max - q_min), where [r_min, r_max] is the range of the current tensor and [q_min, q_max] is the range of the target datatype
  • z = int(round(q_min - r_min/s))
  • z has the same dtype as the quantized tensor
  • z is an integer because it represents zero (in the original ‘r’ range) with an integer in the quantized ‘q’ range
  • if z goes out of range:-
    • z < q_min:-
      • z = q_min
    • z > q_max:-
      • z = q_max
import torch

def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
    
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    scale = (r_max - r_min) / (q_max - q_min)

    zero_point = q_min - (r_min / scale)

    # clip the zero_point to fall in [quantized_min, quantized_max]
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        # round and cast to int
        zero_point = int(round(zero_point))
    
    return scale, zero_point
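As a quick check of these formulas, the sketch below recomputes the scale and zero point in plain Python for the range of the earlier test tensor (r_min = -184, r_max = 728.6). The function name `scale_and_zero_point` is an assumption for this example. It rounds first and clamps afterwards, which gives the same result as the clip-then-round order used above.

```python
def scale_and_zero_point(r_min, r_max, q_min=-128, q_max=127):
    """Asymmetric linear quantization parameters for the range [r_min, r_max]."""
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = round(q_min - r_min / scale)
    # Keep the zero point inside the representable integer range
    zero_point = max(q_min, min(q_max, zero_point))
    return scale, zero_point


scale, zero_point = scale_and_zero_point(-184.0, 728.6)
print(scale, zero_point)

# The ends of the float range should land on the ends of the int8 range
print(round(-184.0 / scale + zero_point))  # code for r_min
print(round(728.6 / scale + zero_point))   # code for r_max
```

Here the scale is 912.6 / 255 (about 3.58) and the zero point is -77, so r_min maps to -128 and r_max maps to 127. For an all-positive range such as [10, 20], the unclamped zero point falls below q_min and is clipped to -128.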

Symmetric vs Asymmetric Mode

  • Asymmetric Mode:-
    • map [r_min, r_max] to [q_min, q_max]
    • This is what we have implemented above
  • Symmetric Mode:-
    • map [-r_max, r_max] to [-q_max, q_max]
      • where r_max = max(|tensor|)
We don’t need a zero point (z = 0), because the floating-point range and the quantized range are both symmetric with respect to zero.

Hence, we can simplify the equation to:-

  • q = int(round(r/s))
  • s = r_max/q_max
import torch

def get_q_scale_symmetric(tensor, dtype=torch.int8):
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max

    # return the scale
    return r_max/q_max

def linear_q_symmetric(tensor, dtype=torch.int8):
    # Pass dtype through so the scale matches the target integer range
    scale = get_q_scale_symmetric(tensor, dtype=dtype)

    # In symmetric quantization the zero point is 0
    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor, scale=scale, zero_point=0, dtype=dtype)

    return quantized_tensor, scale
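The simplified symmetric equations can be checked by hand with a few numbers. This is a plain-Python sketch, not the PyTorch implementation above; the sample values and the helper names are assumptions for the example.

```python
def symmetric_scale(values, q_max=127):
    """s = r_max / q_max, where r_max = max(|values|)."""
    return max(abs(v) for v in values) / q_max


def quantize_symmetric(values, q_max=127):
    """q = clamp(round(r / s)), with the zero point fixed at 0."""
    scale = symmetric_scale(values, q_max)
    q = [max(-q_max - 1, min(q_max, round(v / scale))) for v in values]
    return q, scale


values = [191.6, -13.5, 728.6, -184.0, 0.0]
q, scale = quantize_symmetric(values)
print(q, scale)
```

The largest magnitude (728.6) lands on 127, and 0.0 maps to the integer 0, which is what makes the zero point unnecessary.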

Trade-off

  • Utilization of quantized range:

    • When using asymmetric quantization, the quantized range is fully utilized
    • In symmetric mode, if the float range is biased towards one side, part of the quantized range is dedicated to values that we’ll never see (e.g. after ReLU, where the output is never negative)
  • Simplicity:

    • Symmetric mode is much simpler compared to asymmetric mode.
  • Memory: We don’t have to store zero-point for symmetric quantization

  • We use symmetric quantization for 8-bit, but for lower bit widths such as 2 or 4 bits we use asymmetric quantization
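The range-utilization point can be made concrete by counting how many distinct int8 codes each mode actually produces. This is a sketch in plain Python; the ReLU-like input (1000 evenly spaced values in [0, 6]) and the helper `codes_used` are assumptions invented for the illustration.

```python
def codes_used(values, scale, zero_point, q_min=-128, q_max=127):
    """Return the set of integer codes produced when quantizing `values`."""
    return {max(q_min, min(q_max, round(v / scale + zero_point))) for v in values}


# Hypothetical ReLU-like activations: 1000 evenly spaced values in [0, 6]
values = [6.0 * i / 999 for i in range(1000)]
r_min, r_max = min(values), max(values)

# Asymmetric: map [r_min, r_max] onto [-128, 127]
asym_scale = (r_max - r_min) / 255
asym_zero_point = round(-128 - r_min / asym_scale)

# Symmetric: map [-r_max, r_max] onto [-127, 127] with zero_point = 0
sym_scale = max(abs(v) for v in values) / 127

asym_codes = codes_used(values, asym_scale, asym_zero_point)
sym_codes = codes_used(values, sym_scale, 0)
print(len(asym_codes), len(sym_codes))
```

For this input the asymmetric mode uses all 256 codes, while the symmetric mode uses only the 128 non-negative ones. Half of the int8 range is reserved for negative values that never occur, which hurts more and more as the bit width shrinks.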

Finer Granularity for more Precision

  • Different granularities
    • per tensor
    • per channel (along an axis)
    • per group (group n elements together)
  • The more granular the quantization is, the more precise it will be.

Per Channel Quantization

  • we usually use per channel quantization in int8
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
    
    output_dim = r_tensor.shape[dim]
    # store the scales
    scale = torch.zeros(output_dim)

    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # reshape the scale
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)
   
    return quantized_tensor, scale
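To illustrate why a scale per channel helps, the following plain-Python sketch compares the mean absolute error of per-tensor and per-row symmetric quantization. The weight values are made up so that one row is much smaller than the other, and the helper names are assumptions for the example.

```python
def quantize_row(row, scale, q_max=127):
    """Symmetric quantization of one row with a given scale."""
    return [max(-q_max - 1, min(q_max, round(v / scale))) for v in row]


def mean_abs_error(matrix, scales):
    """Quantize each row with its scale, dequantize, and average |error|."""
    total, count = 0.0, 0
    for row, scale in zip(matrix, scales):
        for v, q in zip(row, quantize_row(row, scale)):
            total += abs(v - scale * q)
            count += 1
    return total / count


# Hypothetical weights: one row with small values, one with large values
weights = [
    [0.11, -0.52, 0.33, 0.94],
    [120.0, -340.0, 275.0, 60.0],
]

# Per tensor: a single scale shared by every row
tensor_scale = max(abs(v) for row in weights for v in row) / 127
per_tensor_error = mean_abs_error(weights, [tensor_scale] * len(weights))

# Per channel (dim=0): each row gets its own scale
row_scales = [max(abs(v) for v in row) / 127 for row in weights]
per_channel_error = mean_abs_error(weights, row_scales)

print(per_tensor_error, per_channel_error)
```

With a single shared scale, every value in the small row rounds to 0 and is lost entirely. With one scale per row, the small row keeps its own resolution while the large row is quantized exactly as before.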

Per Group Quantization

  • Group n (e.g. 32, 64, 128) elements together and quantize each group

  • Per group quantization can require a lot of memory

    • Let’s say we want to quantize a tensor in 4-bit and we choose group_size=32, symmetric mode (z=0), and we store the scales in FP16
    • This means we are actually quantizing the tensor at 4.5 bits per element, since we have:
      • 4 bits (each element is stored in 4 bits)
      • 16/32 = 0.5 bits (one 16-bit scale for every 32 elements)
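The same arithmetic generalizes to other group sizes. A small sketch, where the function name and the optional `zero_point_bits` parameter (for an asymmetric variant that also stores a zero point per group) are assumptions added for illustration:

```python
def effective_bits_per_element(weight_bits, group_size,
                               scale_bits=16, zero_point_bits=0):
    """Storage cost per element once the per-group metadata is amortized."""
    return weight_bits + (scale_bits + zero_point_bits) / group_size


for group_size in (32, 64, 128):
    print(group_size, effective_bits_per_element(4, group_size))
```

Group size 32 gives the 4.5 bits from the example above, and group size 128 gives 4.125 bits. Larger groups cost less storage, but each scale then has to cover more values.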
def linear_q_symmetric_per_group(tensor, group_size,
                                 dtype=torch.int8):
    
    t_shape = tensor.shape
    # Check the rank first, so that indexing t_shape[1] is safe
    assert tensor.dim() == 2
    assert t_shape[1] % group_size == 0
    
    tensor = tensor.view(-1, group_size)
    
    quantized_tensor, scale = linear_q_symmetric_per_channel(
                                tensor, dim=0, dtype=dtype)
    
    quantized_tensor = quantized_tensor.view(t_shape)
    
    return quantized_tensor, scale

def linear_dequantization_per_group(quantized_tensor, scale, 
                                    group_size):
    
    q_shape = quantized_tensor.shape
    quantized_tensor = quantized_tensor.view(-1, group_size)
    
    dequantized_tensor = linear_dequantization(quantized_tensor, 
                                               scale, 0)
    
    dequantized_tensor = dequantized_tensor.view(q_shape)
    
    return dequantized_tensor

Quantizing Weights and Activations for Inference

  • Depending on what we quantize, the storage and the computation are not the same.
  • W8A32
    • If the weights are quantized but the activations are not, the computation is done in floating point (FP16, FP32, BF16)
    • We need to dequantize the weights to perform the floating-point computation (e.g. cast to float32)
  • W8A8
    • Both weights and activations are quantized
    • Computation is integer based, but this is not supported by all hardware
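The difference between the two schemes can be sketched on a single dot product in plain Python. The weight and activation values and the helper `quantize_symmetric` are assumptions for this example; real kernels work on whole matrices, but the order of operations is the same.

```python
def quantize_symmetric(values, q_max=127):
    """Symmetric int8 quantization of a list of floats."""
    scale = max(abs(v) for v in values) / q_max
    return [round(v / scale) for v in values], scale


weights = [0.42, -1.30, 0.07, 0.88]
activations = [2.5, 0.4, -1.1, 3.0]
reference = sum(w * x for w, x in zip(weights, activations))

q_w, s_w = quantize_symmetric(weights)
q_x, s_x = quantize_symmetric(activations)

# W8A32 style: dequantize the weights, then do the arithmetic in floating point
w8a32 = sum((s_w * qw) * x for qw, x in zip(q_w, activations))

# W8A8 style: integer multiply-accumulate, then a single rescale at the end
int_accumulator = sum(qw * qx for qw, qx in zip(q_w, q_x))
w8a8 = s_w * s_x * int_accumulator

print(reference, w8a32, w8a8)
```

In the W8A8 path the whole inner loop runs on integers, and the two scales are applied only once to the accumulated result. That inner loop is the part that needs hardware support for integer matrix multiplication.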

Custom Build an 8-Bit Quantizer

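# W8A16LinearLayer forward pass
import torch.nn.functional as F

def w8_a16_forward(weight, input, scales, bias=None):
    # Cast the int8 weights to the activation dtype (e.g. bfloat16)
    casted_weights = weight.to(input.dtype)
    # Apply the per-output-feature scales after the matmul
    output = F.linear(input, casted_weights) * scales
    
    if bias is not None:
        output = output + bias
      
    return output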
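Note that `w8_a16_forward` multiplies the output by `scales` instead of dequantizing the weight matrix first. With one scale per output row, the two orders are mathematically equivalent. The plain-Python sketch below checks this on a tiny example; all values and the `matvec` helper are assumptions for the illustration.

```python
def matvec(matrix, vector):
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


# Shape (out_features=2, in_features=3), with one scale per output row
int8_weights = [[127, -64, 10], [-25, 127, 50]]
scales = [0.02, 0.005]
x = [1.5, -2.0, 0.25]

# Option 1: dequantize the whole weight matrix, then multiply
dequantized = [[s * q for q in row] for row, s in zip(int8_weights, scales)]
out_dequant_first = matvec(dequantized, x)

# Option 2 (the order used in w8_a16_forward): multiply with the raw
# integer values, then scale each output element
out_scale_after = [s * y for s, y in zip(scales, matvec(int8_weights, x))]

print(out_dequant_first, out_scale_after)
```

Scaling after the matmul touches only `out_features` values per input row instead of every weight, which is why it is the cheaper order.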
import torch
import torch.nn as nn
import torch.nn.functional as F

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        self.int8_weights = nn.Parameter(torch.Tensor([0, 1]
                                     ).to(dtype=torch.int8))

try:
    
    W8A16LinearLayer(1, 1)
    
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")
 RuntimeError :  Only Tensors of floating point and complex dtype can require gradients 
  • When we create an nn.Parameter, PyTorch expects to be able to compute gradients on it.

  • The issue is that PyTorch cannot compute gradients on INT8 tensors.

  • So the code snippet above raises an error saying that only tensors of floating point and complex dtype can require gradients.

  • The right approach for storing INT8 weights is to call the register_buffer method instead of assigning the attribute as an nn.Parameter.

  • This way we store a buffer rather than a parameter, meaning we don’t need to compute gradients on the tensor.

  • A buffer can be initialized with whatever dtype you want.

import torch
import torch.nn as nn
import torch.nn.functional as F

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )
        
        # Buffers rather than parameters: we are interested in inference only
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)

Quantize a Base Model

import torch
import torch.nn as nn
import torch.nn.functional as F

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, 
                 bias=True, dtype=torch.float32):
        super().__init__()
        
        
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 127, (out_features, in_features), dtype=torch.int8
            )
        )
        
        self.register_buffer("scales", 
                             torch.randn((out_features), dtype=dtype))
        
        if bias:
            self.register_buffer("bias", 
                                 torch.randn((1, out_features), 
                                             dtype=dtype))
        
        else:
            self.bias = None

    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights
                        /scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales
    
    def forward(self, input):
        return w8_a16_forward(self.int8_weights, 
                              input, self.scales, self.bias)

module = W8A16LinearLayer(4, 8)
print("Weights before:\n" , module.int8_weights)
random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)
module.quantize(random_matrix)
print("Weights After:\n" , module.int8_weights)
print("Average quantization error:-",(random_matrix - module.int8_weights 
 * module.scales.unsqueeze(1)).abs().mean())
Weights before:
 tensor([[ -55,   55,   41,   27],
        [   1,  -81,   34,  126],
        [ -12,   35, -104,  100],
        [  63,  -58, -109,   10],
        [ 116,  -24,   -1, -124],
        [ -48,    4,   -9,   23],
        [ -63,  -48,  -18,  -26],
        [   9,  -82,   90,    9]], dtype=torch.int8)
Weights After:
 tensor([[-126,  -68,   46,  -70,  126,    9,   22,  -75],
        [ -10,  100,   96,   48,  -68,  126,  -20,   15],
        [  47,  102,  -14,   27, -127,   76,  -31,   25],
        [  40,  -55,   62,  127,  -10,    6,  -34,  -66]], dtype=torch.int8)
Average quantization error:- tensor(0.0039, dtype=torch.bfloat16)

Replace PyTorch layers with Quantized Layers

  • For language models, it is better not to quantize the last layer (the LM head).
import torch
import torch.nn as nn
import torch.nn.functional as F

def replace_linear_with_target(module, 
                               target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
          any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias

            new_module = target_class(child.in_features, 
                                      child.out_features, 
                                      old_bias is not None, 
                                      child.weight.dtype)
            setattr(module, name, new_module)
            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(
                child, target_class, module_name_to_exclude)


def replace_linear_with_target_and_quantize(module, 
                               target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
        any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features, 
                                      child.out_features, 
                                      old_bias is not None, 
                                      child.weight.dtype)
            setattr(module, name, new_module) # current module is replaced by new_module

            getattr(module, name).quantize(old_weight)
            
            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child, 
                     target_class, module_name_to_exclude)

class DummyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = torch.nn.Embedding(1, 1)
    # Try with bias
    self.linear_1 = nn.Linear(1, 1)
    # Try without bias
    self.linear_2 = nn.Linear(1, 1, bias=False)
    # Lm prediction head
    self.lm_head = nn.Linear(1, 1, bias=False)

model_1 = DummyModel()
model_2 = DummyModel()
replace_linear_with_target_and_quantize(model_1, W8A16LinearLayer, ["lm_head"])
print("model_1",model_1)

replace_linear_with_target_and_quantize(model_2, W8A16LinearLayer, [])
print("model_2",model_2)
model_1 DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): Linear(in_features=1, out_features=1, bias=False)
)
model_2 DummyModel(
  (emb): Embedding(1, 1)
  (linear_1): W8A16LinearLayer()
  (linear_2): W8A16LinearLayer()
  (lm_head): W8A16LinearLayer()
)

Quantize any Open Source PyTorch Model

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "Salesforce/codegen-350M-mono"

model = AutoModelForCausalLM.from_pretrained(model_id, 
                                    torch_dtype=torch.bfloat16, 
                                             low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False)[0]["generated_text"])
print("Model before:\n\n", model)
replace_linear_with_target_and_quantize(model, 
                                        W8A16LinearLayer, ["lm_head"])

print("Model after:\n\n", model)
print(pipe("def hello_world():", max_new_tokens=20, 
           do_sample=False)[0]["generated_text"])
Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
def hello_world():
    print("Hello World")

hello_world()

# 파
Model before:

 CodeGenForCausalLM(
  (transformer): CodeGenModel(
    (wte): Embedding(51200, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-19): 20 x CodeGenBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): CodeGenAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (qkv_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): CodeGenMLP(
          (fc_in): Linear(in_features=1024, out_features=4096, bias=True)
          (fc_out): Linear(in_features=4096, out_features=1024, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=51200, bias=True)
)
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Model after:

 CodeGenForCausalLM(
  (transformer): CodeGenModel(
    (wte): Embedding(51200, 1024)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-19): 20 x CodeGenBlock(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): CodeGenAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (qkv_proj): W8A16LinearLayer()
          (out_proj): W8A16LinearLayer()
        )
        (mlp): CodeGenMLP(
          (fc_in): W8A16LinearLayer()
          (fc_out): W8A16LinearLayer()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=51200, bias=True)
)
def hello_world():
    print("Hello World")

# hello_world()

# def hello_
  • The code snippet above modifies the model in place.
  • Don’t quantize the lm_head, otherwise the model will not give the desired results.
  • Rounding errors can accumulate once you start generating many tokens, until they may become large enough to affect the model’s performance.

Load your Quantized Weights from HuggingFace Hub

  • The idea is to quantize the weights on a bigger instance and then push them to the Hugging Face Hub, so that we don’t have to load and quantize the model again and again.
  • Then use PyTorch’s meta device to load only the skeleton of the model instead of the whole model.
  • Replace the original layers with the quantized layers.
  • Load the quantized weights from the Hugging Face Hub.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

replace_linear_with_target_and_quantize(model, 
                             W8A16LinearLayer, 
                                   ["lm_head"])

model

quantized_state_dict = model.state_dict()
torch.save(quantized_state_dict, r"C:\wsl\random\models\quantized_state_dict.pth")

How to upload on HF

from huggingface_hub import HfApi, create_repo

YOUR_HF_USERNAME = ""
your_repo_id = f"{YOUR_HF_USERNAME}/opt-125m-quantized-dlai"

api = HfApi()

create_repo(your_repo_id)

api.upload_file(
    path_or_fileobj="quantized_state_dict.pth",
    path_in_repo="quantized_state_dict.pth",
    repo_id=your_repo_id
)

import torch
from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig

model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id)

with torch.device("meta"):
  model = OPTForCausalLM(config)

tokenizer = AutoTokenizer.from_pretrained(model_id)

for param in model.parameters():
  print(param)
  break
Parameter containing:
tensor(..., device='meta', size=(50272, 768), requires_grad=True)
model
OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=50272, bias=False)
)
replace_linear_with_target(model, W8A16LinearLayer, ["lm_head"])
model
OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-11): 12 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): W8A16LinearLayer()
            (v_proj): W8A16LinearLayer()
            (q_proj): W8A16LinearLayer()
            (out_proj): W8A16LinearLayer()
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): W8A16LinearLayer()
          (fc2): W8A16LinearLayer()
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=50272, bias=False)
)

If loading from HF

from huggingface_hub import hf_hub_download

state_dict_cache_path = hf_hub_download(
    "ybelkada/opt-125m-quantized-dlai",
    "quantized_state_dict.pth"
)

state_dict = torch.load(r"C:\wsl\random\models\quantized_state_dict.pth")
model.load_state_dict(state_dict, strict=True, assign=True)
<All keys matched successfully>
from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am", max_new_tokens=40)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am giving a course about", max_new_tokens=10)
Device set to use cuda:0
Device set to use cuda:0
[{'generated_text': 'Hello today I am giving a course about the history of the world and the history of the'}]

Weights Packing

  • Weights packing is important for storing quantized weights: torch.int4 is not available as of today, so we need to store and load the weights in int8
  • This is not ideal because:
    • The tensor will occupy 8 bits per data point, which might add considerable overhead for large models
    • There would be no point in quantizing to 2/4 bits because we would still be using 8 bits
  • So, we need to pack values
  • Consider a tensor with 4 values, each with 2-bit precision (0, 1, 2, 3) but stored in 8 bits
    • tensor = torch.tensor([1,0,3,2], dtype=torch.uint8)
    • 1:- 00000001
    • 0:- 00000000
    • 3:- 00000011
    • 2:- 00000010
  • We can pack all these values into a single 8-bit value as 177
    • 177:- 10110001
  • Advantages:-
    • It reflects the true memory footprint of the quantized weights
  • Disadvantages:-
    • The length of the unpacked tensor needs to be a multiple of 8 // nbits
    • The weights need to be unpacked before performing an operation

Packing 2-bit Weights

import torch

# Example Tensor: [1, 0, 3, 2]
    # 1 0 3 2 - 01 00 11 10

    # Starting point of packed int8 Tensor
    # [0000 0000]

    ##### First Iteration Start:
    # packed int8 Tensor State: [0000 0000]
    # 1 = 0000 0001
    # 0000 0001
    # No left shifts in the First Iteration
    # After bit-wise OR operation between 0000 0000 and 0000 0001:
    # packed int8 Tensor State: 0000 0001
    ##### First Iteration End

    ##### Second Iteration Start:
    # packed int8 Tensor State: [0000 0001]
    # 0 = 0000 0000
    # 0000 0000
    # 2 left shifts:
    # [0000 0000] (1 shift)-> 0000 0000 (2 shift)-> 0000 0000
    # After bit-wise OR operation between 0000 0001 and 0000 0000:
    # packed int8 Tensor State: 0000 0001
    ##### Second Iteration End

    ##### Third Iteration Start:
    # packed int8 Tensor State: [0000 0001]
    # 3 = 0000 0011
    # 0000 0011
    # 4 left shifts:
    # [0000 0011] (1 shift)-> 0000 0110 (2 shift)-> 0000 1100
    # 0000 1100 (3 shift)-> 0001 1000 (4 shift)-> 0011 0000
    # After bit-wise OR operation between 0000 0001 and 0011 0000:
    # packed int8 Tensor State: 0011 0001
    ##### Third Iteration End

    ##### Fourth Iteration Start:
    # packed int8 Tensor State: [0011 0001]
    # 2 = 0000 0010
    # 0000 0010
    # 6 left shifts:
    # [0000 0010] (1 shift)-> 0000 0100 (2 shift)-> 0000 1000
    # 0000 1000 (3 shift)-> 0001 0000 (4 shift)-> 0010 0000
    # 0010 0000 (5 shift)-> 0100 0000 (6 shift)-> 1000 0000
    # After bit-wise OR operation between 0011 0001 and 1000 0000:
    # packed int8 Tensor State: 1011 0001
    ##### Fourth Iteration End

    # Final packed int8 Tensor State: [1011 0001]

def pack_weights(uint8tensor, bits):
    if uint8tensor.shape[0] * bits % 8 != 0:
        raise ValueError(f"The input shape needs to be a multiple "
                         f"of {8 // bits} - got {uint8tensor.shape[0]}")

    num_values = uint8tensor.shape[0] * bits // 8

    num_steps = 8 // bits

    unpacked_idx = 0

    packed_tensor = torch.zeros((num_values), dtype=torch.uint8)

    # 1 0 3 2 - 01 00 11 10

    # [0000 0000] -> 0000 0001

    # 0000 0001

    # 0000 0000 - 0000 0000

    # 0000 0011 - 0011 0000 - 0011 0001

    # 1011 0001
    
    for i in range(num_values):
        for j in range(num_steps):
            packed_tensor[i] |= uint8tensor[unpacked_idx] << (bits * j)
            unpacked_idx += 1
    return packed_tensor
unpacked_tensor = torch.tensor([1, 0, 3, 2], 
                               dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)
unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], 
                               dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)
tensor([177, 255], dtype=torch.uint8)

Unpacking 2-Bit Weights

import torch

# Example Tensor: [10110001]
    # Which was Originally: 1 0 3 2 - 01 00 11 10

    # Starting point of unpacked Tensor
    # [00000000 00000000 00000000 00000000]

    ##### First Iteration Start:
    # packed int8 Tensor: [10110001]
    # You want to extract 01 from [101100 01]
    # No right shifts in the First Iteration
    # After bit-wise OR operation between 00000000 and 10110001:
    # [10110001 00000000 00000000 00000000]
    # unpacked Tensor state: [10110001 00000000 00000000 00000000]
    ##### First Iteration End

    ##### Second Iteration Start:
    # packed int8 Tensor: [10110001]
    # You want to extract 00 from [1011 00 01]
    # 2 right shifts:
    # [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
    # After bit-wise OR operation between 00000000 and 00101100:
    # [10110001 00101100 00000000 00000000]
    # unpacked Tensor state: [10110001 00101100 00000000 00000000]
    ##### Second Iteration End

    ##### Third Iteration Start:
    # packed int8 Tensor: [10110001]
    # You want to extract 11 from [10 11 0001]
    # 4 right shifts:
    # [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
    # 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
    # After bit-wise OR operation between 00000000 and 00001011:
    # [10110001 00101100 00001011 00000000]
    # unpacked Tensor state: [10110001 00101100 00001011 00000000]
    ##### Third Iteration End

    ##### Fourth Iteration Start:
    # packed int8 Tensor: [10110001]
    # You want to extract 10 from [10 110001]
    # 6 right shifts:
    # [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
    # 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
    # 00001011 (5 shift)-> 00000101 (6 shift)-> 00000010
    # After bit-wise OR operation between 00000000 and 00000010:
    # [10110001 00101100 00001011 00000010]
    # unpacked Tensor state: [10110001 00101100 00001011 00000010]
    ##### Fourth Iteration End

    # Last step: Perform masking (bit-wise AND operation)
    # Mask: 00000011
    # Bit-wise AND operation between 
    # unpacked Tensor and 00000011
    # [10110001 00101100 00001011 00000010] <- unpacked tensor
    # [00000011 00000011 00000011 00000011] <- Mask
    # [00000001 00000000 00000011 00000010] <- Result

    # Final
    # unpacked Tensor state: [00000001 00000000 00000011 00000010]

def unpack_weights(uint8tensor, bits):
    num_values = uint8tensor.shape[0] * 8 // bits

    num_steps = 8 // bits

    unpacked_tensor = torch.zeros((num_values), dtype=torch.uint8)

    unpacked_idx = 0

    # 1 0 3 2 - 01 00 11 10

    # [00000000 00000000 00000000 00000000]
    # [10110001 00101100 00001011 00000010]
    # [00000001 00000000 00000011 00000010]

    # 10110001
    # 00000011
    
    # 00000001

    # 1: [10110001]
    # 2: [00101100]
    # 3: [00001011]

    mask = 2 ** bits - 1

    for i in range(uint8tensor.shape[0]):
        for j in range(num_steps):
            unpacked_tensor[unpacked_idx] |= uint8tensor[i] >> (bits * j)
            unpacked_idx += 1

    unpacked_tensor &= mask
    return unpacked_tensor

packed_tensor = torch.tensor([177, 255], 
                             dtype=torch.uint8)
# Answer should be: torch.tensor([1, 0, 3, 2, 3, 3, 3, 3])
unpack_weights(packed_tensor, 2)
tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
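The same shift-and-mask scheme works for any bit width that divides 8. As a cross-check, here is a compact sketch of packing and unpacking on plain Python integers instead of tensors, applied to both the 2-bit example above and a 4-bit example. The function names `pack` and `unpack` and the 4-bit input values are assumptions for this illustration.

```python
def pack(values, bits):
    """Pack small unsigned ints (each < 2**bits) into bytes, lowest bits first."""
    per_byte = 8 // bits
    if len(values) % per_byte != 0:
        raise ValueError(f"Need a multiple of {per_byte} values, got {len(values)}")
    packed = []
    for start in range(0, len(values), per_byte):
        byte = 0
        for j, value in enumerate(values[start:start + per_byte]):
            if not 0 <= value < (1 << bits):
                raise ValueError(f"{value} does not fit in {bits} bits")
            byte |= value << (bits * j)
        packed.append(byte)
    return packed


def unpack(packed, bits):
    """Inverse of pack: shift each field down and mask off the rest."""
    mask = (1 << bits) - 1
    per_byte = 8 // bits
    return [(byte >> (bits * j)) & mask for byte in packed for j in range(per_byte)]


print(pack([1, 0, 3, 2, 3, 3, 3, 3], bits=2))   # same input as the 2-bit example
four_bit = [9, 15, 0, 7]
packed = pack(four_bit, bits=4)
print(packed, unpack(packed, bits=4))
```

For the 4-bit input, two values share one byte: 9 and 15 become 249 (0xF9), and 0 and 7 become 112 (0x70). Unpacking restores the original list.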

Beyond Linear Quantization

  • Emergent features at scale:- characteristics or features that appear only at scale, when the model is large.

  • The magnitude of the features predicted by the model (the hidden states) starts to get large, which makes the classic quantization schemes obsolete: classic linear quantization algorithms simply fail on these models.

  • So how do we deal with outlier features in LLMs?

  • Outlier features simply means hidden states with large magnitude.

  • Several interesting papers address this, such as LLM.int8, SmoothQuant and AWQ.

    • LLM.int8 separates the matmul in two steps:-
      • For non-outliers (smaller values)
        • Perform matmul in int8, then dequantize it.
      • For outliers (larger values)
        • Perform the matmul in the classical way, i.e. in the dtype of the hidden states (usually half precision), then combine the two results
    • SmoothQuant
      • Applies a W8A8 scheme (quantizes both weights and activations)
      • Given an input, it determines a smoothing factor and uses it when quantizing.
      • Migrates the scale variance from activations to weights to reduce the quantization difficulty of activations.
      • The smoothed activations and the adjusted weights are both easy to quantize.
    • AWQ
      • Uses a calibration dataset to find which weights could be responsible for generating outlier features; these are called salient weights.
      • It then uses that information to scale the weights before quantization, and uses the same scale during inference to rescale the input as well.
  • Recent SOTA quantization methods:

    • LLM.int8
    • GPTQ
    • SmoothQuant
    • QLoRA
    • AWQ
    • QuIP#
    • HQQ
    • AQLM
    • ………..
  • Challenges of Quantization

    • Retraining (Quantization Aware Training) [less explored]
    • Limited Hardware support
    • Calibration dataset needed
    • packing/unpacking
  • Some Other resources

    • MIT Han Lab
    • Huggingface transformers quantization docs/blogposts
    • llama.cpp discussions
    • Reddit LocalLlama
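The two-step matmul described for LLM.int8 above can be sketched on a single dot product in plain Python. This is only an illustration of the idea, not the library's implementation: the hidden-state values, the weights, the `threshold` of 6.0 and all function names are assumptions for this example, and a real implementation works on whole matrices.

```python
def quantize_symmetric(values, q_max=127):
    """Symmetric int8 quantization of a list of floats."""
    r_max = max(abs(v) for v in values)
    scale = r_max / q_max if r_max > 0 else 1.0
    return [round(v / scale) for v in values], scale


def int8_dot(hidden, weights):
    """Dot product computed on int8 codes, rescaled once at the end."""
    q_h, s_h = quantize_symmetric(hidden)
    q_w, s_w = quantize_symmetric(weights)
    return s_h * s_w * sum(a * b for a, b in zip(q_h, q_w))


def mixed_precision_dot(hidden, weights, threshold=6.0):
    """Outlier dimensions stay in float; the rest go through int8."""
    regular = [i for i, h in enumerate(hidden) if abs(h) < threshold]
    outliers = [i for i, h in enumerate(hidden) if abs(h) >= threshold]
    result = 0.0
    if regular:
        result += int8_dot([hidden[i] for i in regular],
                           [weights[i] for i in regular])
    # Outlier part in the original precision, then combined with the int8 part
    result += sum(hidden[i] * weights[i] for i in outliers)
    return result


# Hypothetical hidden state with a single large-magnitude (outlier) feature
hidden = [0.3, -1.2, 45.0, 0.8, -0.5]
weights = [0.5, -0.25, 0.1, 0.75, 0.3]

exact = sum(h * w for h, w in zip(hidden, weights))
naive = int8_dot(hidden, weights)
mixed = mixed_precision_dot(hidden, weights)
print(abs(naive - exact), abs(mixed - exact))
```

In the naive version the outlier 45.0 sets the scale for the whole vector, so the small entries are squeezed into a handful of codes. Handling the outlier separately lets the remaining values use the full int8 range, which gives a noticeably smaller error on this example.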