import torch
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
#Helper functions to visualize
def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
"""
Plot a heatmap of tensors using seaborn
"""
sns.heatmap(tensor.cpu().numpy(), ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, annot=True, fmt=".2f", cbar=False)
ax.set_title(title)
ax.set_yticklabels([])
ax.set_xticklabels([])
def plot_quantization_errors(original_tensor, quantized_tensor, dequantized_tensor, dtype = torch.int8, n_bits = 8):
"""
A method that plots 4 matrices, the original tensor, the quantized tensor
the de-quantized tensor and the error tensor.
"""
# Get a figure of 4 plots
fig, axes = plt.subplots(1, 4, figsize=(15, 4))
# Plot the first matrix
plot_matrix(original_tensor, axes[0], 'Original Tensor', cmap=ListedColormap(['white']))
# Get the quantization range and plot the quantized tensor
q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
plot_matrix(quantized_tensor, axes[1], f'{n_bits}-bit Linear Quantized Tensor', vmin=q_min, vmax=q_max, cmap='coolwarm')
# Plot the de-quantized tensors
plot_matrix(dequantized_tensor, axes[2], 'Dequantized Tensor', cmap='coolwarm')
# Get the quantization errors
q_error_tensor = abs(original_tensor - dequantized_tensor)
plot_matrix(q_error_tensor, axes[3], 'Quantization Error Tensor', cmap=ListedColormap(['white']))
fig.tight_layout()
plt.show()
def linear_q_with_scale_and_zero_point(
tensor, scale, zero_point, dtype = torch.int8):
scaled_and_shifted_tensor = tensor / scale + zero_point
rounded_tensor = torch.round(scaled_and_shifted_tensor)
q_min = torch.iinfo(dtype).min
q_max = torch.iinfo(dtype).max
q_tensor = rounded_tensor.clamp(q_min,q_max).to(dtype)
return q_tensor
test_tensor=torch.tensor(
[[191.6, -13.5, 728.6],
[92.14, 295.5, -184],
[0, 684.6, 245.5]])
scale = 3.5
zero_point = -70
quantized_tensor = linear_q_with_scale_and_zero_point(
test_tensor, scale, zero_point)
def linear_dequantization(quantized_tensor, scale, zero_point):
return scale * (quantized_tensor.float() - zero_point)
dequantized_tensor = linear_dequantization(
quantized_tensor, scale, zero_point)
plot_quantization_errors(test_tensor, quantized_tensor,
                         dequantized_tensor)

Quantization in Depth
These notes are completely based on the Quantization in Depth course.
For the code part, you can check out this link.
Quantize and De-quantize a tensor
- Advantages of Quantization
- Smaller model
- Speed gains
- Memory bandwidth
- Faster operations
- GEMM: General Matrix Multiply (matrix-to-matrix multiplication)
- GEMV: General Matrix Vector Multiplication (matrix to vector multiplication)
- Challenges of Quantization
- Quantization error
- Retraining (Quantization Aware Training)
- Limited Hardware support
- Calibration dataset needed
- packing/unpacking
- getting q:-
- r = s(q - z), hence q = int(round(r/s + z))
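A quick sanity check of these two formulas (a minimal worked example using the scale and zero-point from the code block above, s = 3.5 and z = -70):

r = 191.6
s, z = 3.5, -70
q = int(round(r / s + z))                # round(54.74 - 70) = -15
r_dequant = s * (q - z)                  # 3.5 * (-15 + 70) = 192.5
print(q, r_dequant, abs(r - r_dequant))  # -15, 192.5, ~0.9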
Get the Scale and Zero-Point
- s = (r_max - r_min) / (q_max - q_min), i.e. the tensor's float range divided by the datatype's integer range
- z = int(round(q_min - r_min/s)) (a worked example follows after this list)
- z and the quantized tensor are of the same type
- z is an integer because it represents zero (in the original ‘r’ range) with an integer in the quantized ‘q’ range
- if z goes out of range:-
- z < q_min:-
- z = q_min
- z > q_max:-
- z = q_max
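As a quick check of the scale and zero-point formulas (using the test_tensor defined earlier and int8, so q_min = -128 and q_max = 127): r_min = -184 and r_max = 728.6, hence s = (728.6 + 184) / 255 ≈ 3.58 and z = round(-128 + 184 / 3.58) ≈ -77, which already lies inside [-128, 127], so no clamping is needed.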
import torch
def get_q_scale_and_zero_point(tensor, dtype=torch.int8):
q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
r_min, r_max = tensor.min().item(), tensor.max().item()
scale = (r_max - r_min) / (q_max - q_min)
zero_point = q_min - (r_min / scale)
# clip the zero_point to fall in [quantized_min, quantized_max]
if zero_point < q_min:
zero_point = q_min
elif zero_point > q_max:
zero_point = q_max
else:
# round and cast to int
zero_point = int(round(zero_point))
    return scale, zero_point

Symmetric vs Asymmetric Mode
- Asymmetric Mode:-
- map [r_min, r_max] to [q_min, q_max]
- This is what we have implemented above
- Symmetric Mode:-
- map [-r_max, r_max] to [-q_max, q_max]
- where r_max = max(|tensor|)
Hence, we can simplify the equation to:-
- q = int(round(r/s))
- s = r_max/q_max
import torch
def get_q_scale_symmetric(tensor, dtype=torch.int8):
r_max = tensor.abs().max().item()
q_max = torch.iinfo(dtype).max
# return the scale
return r_max/q_max
def linear_q_symmetric(tensor, dtype=torch.int8):
scale = get_q_scale_symmetric(tensor)
quantized_tensor = linear_q_with_scale_and_zero_point(tensor,
scale=scale,
# in symmetric quantization zero point is = 0
zero_point=0,
dtype=dtype)
    return quantized_tensor, scale

Trade-off
Utilization of quantized range:
- when using asymmetric quantization, the quantized range is fully utilized
- In symmetric mode, if the float range is biased towards one side, part of the quantized range is dedicated to values that we'll never see (e.g. a ReLU output, which is always non-negative); see the short check after this list
Simplicity:
- Symmetric mode is much simpler compared to asymmetric mode.
Memory: We don’t have to store zero-point for symmetric quantization
We typically use symmetric quantization for 8-bit, but as we go to lower bit-widths such as 2 or 4 bits, we use asymmetric quantization.
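A minimal check of the range-utilization point above, reusing the linear_q_symmetric helper defined earlier: for an all-positive tensor (like a ReLU output), symmetric int8 quantization never produces negative codes, so half of the [-128, 127] range goes unused.

import torch
relu_like = torch.rand((4, 4)) * 10            # non-negative, like a ReLU output
q_sym, scale = linear_q_symmetric(relu_like)
print(q_sym.min().item(), q_sym.max().item())  # min stays >= 0, max is close to 127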
Finer Granularity for more Precision
- Different granularities
- per tensor
- per channel (along an axis)
- per group (group n elements together)
- The more granular the quantization, the more precise it will be (a small illustration follows below)
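A small illustration of why finer granularity helps, reusing the linear_q_symmetric and linear_dequantization helpers defined earlier: a single outlier inflates the per-tensor scale, so all the small values get rounded away.

import torch
tensor = torch.tensor([[0.1, 0.2, 0.3],
                       [0.2, 0.1, 1000.0]])    # one large outlier
q, scale = linear_q_symmetric(tensor)          # per-tensor scale ~ 1000/127
dq = linear_dequantization(q, scale, 0)
print((tensor - dq).abs().mean())              # the small values all collapse to 0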
Per Channel Quantization
- we usually use per channel quantization in int8
def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):
output_dim = r_tensor.shape[dim]
# store the scales
scale = torch.zeros(output_dim)
for index in range(output_dim):
sub_tensor = r_tensor.select(dim, index)
scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)
# reshape the scale
scale_shape = [1] * r_tensor.dim()
scale_shape[dim] = -1
scale = scale.view(scale_shape)
quantized_tensor = linear_q_with_scale_and_zero_point(
r_tensor, scale=scale, zero_point=0, dtype=dtype)
    return quantized_tensor, scale

Per Group Quantization
Group n (e.g. 32, 64, 128) elements together and quantize each group
Per group quantization can require a lot of memory
- Let's say we want to quantize a tensor in 4-bit and we choose group_size=32, symmetric mode (z=0), and we store the scales in FP16
- It means we are actually quantizing the tensor to 4.5 bits, since we have:
- 4 bits (each element is stored in 4 bits)
- 16/32 bits (one FP16 scale for every 32 elements); see the quick check after this list
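A quick check of the 4.5-bit figure above (just the arithmetic, nothing course-specific):

group_size = 32
bits_per_element = 4 + 16 / group_size   # 4-bit weight + amortized FP16 scale
print(bits_per_element)                  # 4.5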
def linear_q_symmetric_per_group(tensor, group_size,
dtype=torch.int8):
t_shape = tensor.shape
assert t_shape[1] % group_size == 0
assert tensor.dim() == 2
tensor = tensor.view(-1, group_size)
quantized_tensor, scale = linear_q_symmetric_per_channel(
tensor, dim=0, dtype=dtype)
quantized_tensor = quantized_tensor.view(t_shape)
return quantized_tensor, scale
def linear_dequantization_per_group(quantized_tensor, scale,
group_size):
q_shape = quantized_tensor.shape
quantized_tensor = quantized_tensor.view(-1, group_size)
dequantized_tensor = linear_dequantization(quantized_tensor,
scale, 0)
dequantized_tensor = dequantized_tensor.view(q_shape)
    return dequantized_tensor

Quantizing Weights and Activations for Inference
- Depending on what we quantize, the storage and the computation are not the same.
- W8A32
- If the weights are quantized but not the activations, then the computation is done in floating point (FP16, FP32, BF16)
- We need to dequantize the weights to perform the floating-point computation (cast to float32); a small sketch follows after this list
- W8A8
- Both are quantized
- Computation is integer based but not supported by all hardware
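Below is a minimal sketch of what a W8A32 forward pass could look like (a hypothetical helper, not from the course code): the int8 weights are dequantized back to float32 and then used in a regular floating-point linear layer.

import torch
import torch.nn.functional as F

def w8_a32_forward(int8_weight, scale, zero_point, input_fp32):
    # dequantize the stored int8 weights, then run the matmul in float32
    dequantized_weight = (int8_weight.to(torch.float32) - zero_point) * scale
    return F.linear(input_fp32, dequantized_weight)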
Custom Build an 8-Bit Quantizer
# W8A16LinearLayer
import torch.nn.functional as F

def w8_a16_forward(weight, input, scales, bias=None):
casted_weights = weight.to(input.dtype)
output = F.linear(input, casted_weights) * scales
if bias is not None:
output = output + bias
    return output

import torch
import torch.nn as nn
import torch.nn.functional as F
class W8A16LinearLayer(nn.Module):
def __init__(self, in_features, out_features,
bias=True, dtype=torch.float32):
super().__init__()
self.int8_weights = nn.Parameter(torch.Tensor([0, 1]
).to(dtype=torch.int8))
try:
W8A16LinearLayer(1, 1)
except Exception as error:
print("\033[91m", type(error).__name__, ": ", error, "\033[0m")When we create nn.Parameters, pytorch expects that parameter where it’s able to compute gradients on it.
The issue is that with PyTorch, you can’t explicitly compute gradients on INT8 tensors.
So the above code snippet raises an error saying that only tensors of floating-point and complex dtype can require gradients.
So the right approach to store INT8 weights is not to save the attribute as an nn.Parameter, but to call the register_buffer method instead.
This way, instead of storing a parameter, we just store a buffer, which means we don't need to compute gradients on the tensor.
You can initialize it with whatever dtype you want.
import torch
import torch.nn as nn
import torch.nn.functional as F
class W8A16LinearLayer(nn.Module):
def __init__(self, in_features, out_features,
bias=True, dtype=torch.float32):
super().__init__()
self.register_buffer(
"int8_weights",
torch.randint(
-128, 127, (out_features, in_features), dtype=torch.int8
)
)
self.register_buffer("scales",
torch.randn((out_features), dtype=dtype)) # We are interested in inference only
if bias:
self.register_buffer("bias",
torch.randn((1, out_features),
dtype=dtype))
else:
self.bias = None
def forward(self, input):
return w8_a16_forward(self.int8_weights,
                              input, self.scales, self.bias)

Quantize a Base Model
import torch
import torch.nn as nn
import torch.nn.functional as F
class W8A16LinearLayer(nn.Module):
def __init__(self, in_features, out_features,
bias=True, dtype=torch.float32):
super().__init__()
self.register_buffer(
"int8_weights",
torch.randint(
-128, 127, (out_features, in_features), dtype=torch.int8
)
)
self.register_buffer("scales",
torch.randn((out_features), dtype=dtype))
if bias:
self.register_buffer("bias",
torch.randn((1, out_features),
dtype=dtype))
else:
self.bias = None
def quantize(self, weights):
w_fp32 = weights.clone().to(torch.float32)
scales = w_fp32.abs().max(dim=-1).values / 127
scales = scales.to(weights.dtype)
int8_weights = torch.round(weights
/scales.unsqueeze(1)).to(torch.int8)
self.int8_weights = int8_weights
self.scales = scales
def forward(self, input):
return w8_a16_forward(self.int8_weights,
input, self.scales, self.bias)
module = W8A16LinearLayer(4, 8)
print("Weights before:\n" , module.int8_weights)
random_matrix = torch.randn((4, 8), dtype=torch.bfloat16)
module.quantize(random_matrix)
print("Weights After:\n" , module.int8_weights)
print("Average quantization error:-",(random_matrix - module.int8_weights
      * module.scales.unsqueeze(1)).abs().mean())

Replace PyTorch layers with Quantized Layers
- For language models, it is better not to quantize the last layer (the LM head).
import torch
import torch.nn as nn
import torch.nn.functional as F
def replace_linear_with_target(module,
target_class, module_name_to_exclude):
for name, child in module.named_children():
if isinstance(child, nn.Linear) and not \
any([x == name for x in module_name_to_exclude]):
old_bias = child.bias
new_module = target_class(child.in_features,
child.out_features,
old_bias is not None,
child.weight.dtype)
setattr(module, name, new_module)
if old_bias is not None:
getattr(module, name).bias = old_bias
else:
# Recursively call the function for nested modules
replace_linear_with_target(
child, target_class, module_name_to_exclude)
def replace_linear_with_target_and_quantize(module,
target_class, module_name_to_exclude):
for name, child in module.named_children():
if isinstance(child, nn.Linear) and not \
any([x == name for x in module_name_to_exclude]):
old_bias = child.bias
old_weight = child.weight
new_module = target_class(child.in_features,
child.out_features,
old_bias is not None,
child.weight.dtype)
setattr(module, name, new_module) # current module is replaced by new_module
getattr(module, name).quantize(old_weight)
if old_bias is not None:
getattr(module, name).bias = old_bias
else:
# Recursively call the function for nested modules
replace_linear_with_target_and_quantize(child,
target_class, module_name_to_exclude)
class DummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = torch.nn.Embedding(1, 1)
# Try with bias
self.linear_1 = nn.Linear(1, 1)
# Try without bias
self.linear_2 = nn.Linear(1, 1, bias=False)
# Lm prediction head
self.lm_head = nn.Linear(1, 1, bias=False)
model_1 = DummyModel()
model_2 = DummyModel()
replace_linear_with_target_and_quantize(model_1, W8A16LinearLayer, ["lm_head"])
print("model_1",model_1)
replace_linear_with_target_and_quantize(model_2, W8A16LinearLayer, [])
print("model_2",model_2)Quantize any Open Source PyTorch Model
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
model_id = "Salesforce/codegen-350M-mono"
model = AutoModelForCausalLM.from_pretrained(model_id,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False)[0]["generated_text"])
print("Model before:\n\n", model)
replace_linear_with_target_and_quantize(model,
W8A16LinearLayer, ["lm_head"])
print("Model after:\n\n", model)
print(pipe("def hello_world():", max_new_tokens=20,
           do_sample=False)[0]["generated_text"])

- The above code snippet modifies the model in place
- Also, don't quantize the lm_head, otherwise the model will not give the desired results
- Rounding errors can accumulate as you generate more and more tokens, and may eventually become large enough to hurt the model's output quality
Load your Quantized Weights from HuggingFace Hub
- The idea is to quantize the weights once on a bigger instance and then push them to the Hugging Face Hub, so that we don't have to load and quantize the model again and again.
- Then use the meta device from PyTorch to load only the skeleton of the model instead of the whole model itself.
- Replace the original layers with the quantized layers
- Load the quantized weights from huggingfacehub
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
replace_linear_with_target_and_quantize(model,
W8A16LinearLayer,
["lm_head"])
model
quantized_state_dict = model.state_dict()
torch.save(quantized_state_dict, r"/home/ritesh/github_repos/riteshrm.github.io/tmp/quantized_state_dict.pth")

How to upload on HF
from huggingface_hub import HfApi, create_repo
YOUR_HF_USERNAME = ""
your_repo_id = f"{YOUR_HF_USERNAME}/opt-125m-quantized-dlai"
api = HfApi()
create_repo(your_repo_id)
api.upload_file(
    path_or_fileobj="quantized_state_dict.pth",
    path_in_repo="quantized_state_dict.pth",
    repo_id=your_repo_id
)
import torch
from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig
model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id)
with torch.device("meta"):
model = OPTForCausalLM(config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
for param in model.parameters():
print(param)
    break

model

replace_linear_with_target(model, W8A16LinearLayer, ["lm_head"])
model

If loading from HF
from huggingface_hub import hf_hub_download
state_dict_cache_path = hf_hub_download(
    "ybelkada/opt-125m-quantized-dlai",
    "quantized_state_dict.pth"
)
state_dict = torch.load(r"/home/ritesh/github_repos/riteshrm.github.io/tmp/quantized_state_dict.pth")
model.load_state_dict(state_dict, strict=True, assign=True)

from transformers import pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am", max_new_tokens=40)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Hello today I am giving a course about", max_new_tokens=10)Weights Packing
- Weights packing is important for storing quantized weights, because torch.int4 is not available as of today, so we need to store and load the weights in int8
- This is not ideal because:
- tensor will occupy 8-bit per datapoint and might add a considerable overhead for large models
- There would be no point in quantizing to 2/4 bits because we would still be using 8 bits per value
- So, we need to pack values
- Consider a tensor with 4 values each with 2-bit(0,1,2,3) precision but stored in 8-bit
- tensor = torch.tensor([1,0,3,2], dtype=torch.uint8)
- 1:- 00000001
- 0:- 00000000
- 3:- 00000011
- 2:- 00000010
- We can pack all these values into a single 8-bit value, 177 (see the quick check after this list)
- 177:- 10110001
- Advantages:-
- It reflects the true memory footprint of the quantized weights
- Disadvantages:-
- The unpacked tensor's shape needs to be a multiple of 8 // nbits
- It needs to unpack before performing an operation
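A quick check of the packing example above: packing the 2-bit values [1, 0, 3, 2] low-to-high into a single uint8 indeed gives 0b10110001 = 177.

import torch
values = torch.tensor([1, 0, 3, 2], dtype=torch.uint8)
packed = 0
for j, v in enumerate(values):
    packed |= int(v) << (2 * j)   # each value occupies its own 2-bit slot
print(packed, bin(packed))        # 177 0b10110001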
Packing 2-bit Weights
import torch
# Example Tensor: [1, 0, 3, 2]
# 1 0 3 2 - 01 00 11 10
# Starting point of packed int8 Tensor
# [0000 0000]
##### First Iteration Start:
# packed int8 Tensor State: [0000 0000]
# 1 = 0000 0001
# 0000 0001
# No left shifts in the First Iteration
# After bit-wise OR operation between 0000 0000 and 0000 0001:
# packed int8 Tensor State: 0000 0001
##### First Iteration End
##### Second Iteration Start:
# packed int8 Tensor State: [0000 0001]
# 0 = 0000 0000
# 0000 0000
# 2 left shifts:
# [0000 0000] (1 shift)-> 0000 0000 (2 shift)-> 0000 0000
# After bit-wise OR operation between 0000 0001 and 0000 0000:
# packed int8 Tensor State: 0000 0001
##### Second Iteration End
##### Third Iteration Start:
# packed int8 Tensor State: [0000 0001]
# 3 = 0000 0011
# 0000 0011
# 4 left shifts:
# [0000 0011] (1 shift)-> 0000 0110 (2 shift)-> 0000 1100
# 0000 1100 (3 shift)-> 0001 1000 (4 shift)-> 0011 0000
# After bit-wise OR operation between 0000 0001 and 0011 0000:
# packed int8 Tensor State: 0011 0001
##### Third Iteration End
##### Fourth Iteration Start:
# packed int8 Tensor State: [0011 0001]
# 2 = 0000 0010
# 0000 0010
# 6 left shifts:
# [0000 0010] (1 shift)-> 0000 0100 (2 shift)-> 0000 1000
# 0000 1000 (3 shift)-> 0001 0000 (4 shift)-> 0010 0000
# 0010 0000 (5 shift)-> 0100 0000 (6 shift)-> 1000 0000
# After bit-wise OR operation between 0011 0001 and 1000 0000:
# packed int8 Tensor State: 1011 0001
##### Fourth Iteration End
# Final packed int8 Tensor State: [1011 0001]
def pack_weights(uint8tensor, bits):
if uint8tensor.shape[0] * bits % 8 != 0:
raise ValueError(f"The input shape needs to be a multiple \
of {8 / bits} - got {uint8tensor.shape[0]}")
num_values = uint8tensor.shape[0] * bits // 8
num_steps = 8 // bits
unpacked_idx = 0
packed_tensor = torch.zeros((num_values), dtype=torch.uint8)
# 1 0 3 2 - 01 00 11 10
# [0000 0000] -> 0000 0001
# 0000 0001
# 0000 0000 - 0000 0000
# 0000 0011 - 0011 0000 - 0011 0001
# 1011 0001
for i in range(num_values):
for j in range(num_steps):
packed_tensor[i] |= uint8tensor[unpacked_idx] << (bits * j)
unpacked_idx += 1
return packed_tensor
unpacked_tensor = torch.tensor([1, 0, 3, 2],
dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)
unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3],
dtype=torch.uint8)
pack_weights(unpacked_tensor, 2)

Unpacking 2-Bit Weights
import torch
# Example Tensor: [10110001]
# Which was Originally: 1 0 3 2 - 01 00 11 10
# Starting point of unpacked Tensor
# [00000000 00000000 00000000 00000000]
##### First Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 01 from [101100 01]
# No right shifts in the First Iteration
# After bit-wise OR operation between 00000000 and 10110001:
# [10110001 00000000 00000000 00000000]
# unpacked Tensor state: [10110001 00000000 00000000 00000000]
##### First Iteration End
##### Second Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 00 from [1011 00 01]
# 2 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# After bit-wise OR operation between 00000000 and 00101100:
# [10110001 00101100 00000000 00000000]
# unpacked Tensor state: [10110001 00101100 00000000 00000000]
##### Second Iteration End
##### Third Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 11 from [10 11 0001]
# 4 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
# After bit-wise OR operation between 00000000 and 00001011:
# [10110001 00101100 00001011 00000000]
# unpacked Tensor state: [10110001 00101100 00001011 00000000]
##### Third Iteration End
##### Fourth Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 10 from [10 110001]
# 6 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100
# 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011
# 00001011 (5 shift)-> 00000101 (6 shift)-> 00000010
# After bit-wise OR operation between 00000000 and 00000010:
# [10110001 00101100 00001011 00000010]
# unpacked Tensor state: [10110001 00101100 00001011 00000010]
##### Fourth Iteration End
# Last step: Perform masking (bit-wise AND operation)
# Mask: 00000011
# Bit-wise AND operation between
# unpacked Tensor and 00000011
# [10110001 00101100 00001011 00000010] <- unpacked tensor
# [00000011 00000011 00000011 00000011] <- Mask
# [00000001 00000000 00000011 00000010] <- Result
# Final
# unpacked Tensor state: [00000001 00000000 00000011 00000010]
def unpack_weights(uint8tensor, bits):
num_values = uint8tensor.shape[0] * 8 // bits
num_steps = 8 // bits
unpacked_tensor = torch.zeros((num_values), dtype=torch.uint8)
unpacked_idx = 0
# 1 0 3 2 - 01 00 11 10
# [00000000 00000000 00000000 00000000]
# [10110001 00101100 00001011 00000010]
# [00000001 00000000 00000011 00000010]
# 10110001
# 00000011
# 00000001
# 1: [10110001]
# 2: [00101100]
# 3: [00001011]
mask = 2 ** bits - 1
for i in range(uint8tensor.shape[0]):
for j in range(num_steps):
unpacked_tensor[unpacked_idx] |= uint8tensor[i] >> (bits * j)
unpacked_idx += 1
unpacked_tensor &= mask
return unpacked_tensor
unpacked_tensor = torch.tensor([177, 255],
dtype=torch.uint8)
# Answer should be: torch.tensor([1, 0, 3, 2, 3, 3, 3, 3])
unpack_weights(unpacked_tensor, 2)

Beyond Linear Quantization
Emergent features at scale:- characteristics or features that only appear at scale, i.e. when the model is large.
As models grew, the magnitude of the hidden states started to get large, which made classic linear quantization schemes quite obsolete and caused them to simply fail on these models.
Now, how do we deal with outlier features in LLMs?
Outlier features simply mean hidden states with large magnitude.
There are some interesting papers that address this, such as LLM.int8, SmoothQuant, and AWQ.
- LLM.int8 separates the matmul into two steps (a rough sketch follows after this list):-
- For non-outliers (smaller values)
- Perform the matmul in int8, then dequantize it.
- For outliers (larger values)
- Perform the matmul the classical way (in the dtype of the hidden states, usually half precision), then combine the two results
- SmoothQuant
- Applies A8W8 scheme(quantize weights and activations)
- Given an input, it determines a smoothing factor and uses it to quantize.
- migrates the scale variance from activations to weights to reduce the quantization difficulty of activations.
- the smoothed activation and the adjusted weight are both easy to quantize.
- AWQ
- Uses a calibration dataset to find out which weights could be responsible for generating outlier features; these are called salient weights.
- It then uses that information to scale the weights before quantization, and uses the same scale during inference to rescale the input as well.
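Below is a rough, illustrative sketch of the LLM.int8 decomposition (this is not the bitsandbytes implementation, and it uses simple per-tensor scales where the paper uses vector-wise quantization): activation columns whose magnitude exceeds a threshold are multiplied in the original float dtype, the rest go through an int8 matmul and are then dequantized, and the two partial results are summed.

import torch

def mixed_precision_matmul(x, w, threshold=6.0):
    # x: (batch, in_features) activations, w: (in_features, out_features) weights
    outlier_cols = x.abs().max(dim=0).values > threshold
    regular_cols = ~outlier_cols

    # Outlier part: plain floating-point matmul on the few outlier features
    out_outlier = x[:, outlier_cols] @ w[outlier_cols, :]

    # Regular part: symmetric int8 quantization of both operands
    x_reg, w_reg = x[:, regular_cols], w[regular_cols, :]
    x_scale = x_reg.abs().max() / 127
    w_scale = w_reg.abs().max() / 127
    x_int8 = torch.round(x_reg / x_scale).clamp(-128, 127).to(torch.int8)
    w_int8 = torch.round(w_reg / w_scale).clamp(-128, 127).to(torch.int8)
    # int8 values multiplied with int32 accumulation, then dequantized
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32)
    return out_outlier + acc.to(x.dtype) * (x_scale * w_scale)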
Recent SOTA quantization methods:
- LLM.int8
- GPTQ
- SmoothQuant
- QLoRA
- AWQ
- QuIP#
- HQQ
- AQLM
- ………..
Challenges of Quantization
- Retraining (Quantization Aware Training) [less explored]
- Limited Hardware support
- Calibration dataset needed
- packing/unpacking
Some Other resources
- MIT Han Lab
- Huggingface transformers quantization docs/blogposts
- llama.cpp discussions
- Reddit LocalLlama