from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import *

### BitLinear definition. Source: https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf

def activation_quant(x):
    """Per-token quantization to 8 bits. No grouping is needed for quantization.

    Args:
        x: an activation tensor with shape [n, d]
    Returns:
        y: a quantized activation tensor with shape [n, d]
    """
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y
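
# Illustrative note (an added sketch, not part of the original utility): activation_quant
# is a "fake quantization" - values are snapped to a 256-level grid per row (token) but
# returned in full precision. For example, with an assumed input:
#   activation_quant(torch.tensor([[0.5, -1.0, 0.25]]))
#   # scale = 127 / 1.0, so the result is ~ [[0.5039, -1.0, 0.2520]] (grid step 1/127)
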
def weight_quant(w):
    """Per-tensor quantization to 1.58 bits. No grouping is needed for quantization.

    Args:
        w: a weight tensor with shape [d, k]
    Returns:
        u: a quantized weight with shape [d, k]
    """
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u
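
# Illustrative note (an added sketch, not part of the original utility): weight_quant maps
# the whole tensor onto the ternary set {-1, 0, +1}, rescaled by the mean absolute weight.
# For example, with an assumed input:
#   weight_quant(torch.tensor([[0.4, -0.05, 0.9]]))
#   # scale = 1 / 0.45; rounded and clamped to [1, 0, 1]; result is ~ [[0.45, 0.0, 0.45]]
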
class BitLinear(nn.Linear):
    """
    This is only for training, and kernel optimization is needed for efficiency.
    """

    def forward(self, x):
        """
        Args:
            x: an input tensor with shape [n, d]
        Returns:
            y: an output tensor with shape [n, d]
        """
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        # A fresh LlamaRMSNorm is built on every call; its scale is initialized to ones,
        # so this acts as a plain (non-learned) RMS normalization of the activations.
        rms_norm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = rms_norm(x)
        # A trick for implementing the Straight-Through-Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y
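
# Quick usage sketch (an added example, not part of the original utility): BitLinear is a
# drop-in replacement for nn.Linear at call time; normalization and fake quantization of
# x and w happen inside forward.
#   layer = BitLinear(16, 32, bias=False)
#   y = layer(torch.randn(4, 16))  # -> shape [4, 32]
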
"""
Converts a LlamaForCausalLM model to the BitNet architecture.
There are two steps to achieve this according to the released training details:
1. Replace all nn.Linear layers in attention and SwiGLU with BitLinear.
2. Remove the RMSNorm before attention and SwiGLU because BitLinear has a built-in RMSNorm.

Args:
    model: A LlamaForCausalLM model
    copy_weights: Boolean value indicating whether to copy the weights of the existing
        linear layers into the BitLinear layers. Useful for continued pretraining.
"""
def convert_to_bitnet(model, copy_weights):
    for name, module in model.named_modules():
        # Replace linear layers with BitLinear
        if isinstance(module, (LlamaSdpaAttention, LlamaMLP)):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, nn.Linear):
                    bitlinear = BitLinear(
                        child_module.in_features,
                        child_module.out_features,
                        child_module.bias is not None,
                    ).to(device="cuda:0")
                    if copy_weights:
                        bitlinear.weight = child_module.weight
                        if child_module.bias is not None:
                            bitlinear.bias = child_module.bias
                    setattr(module, child_name, bitlinear)
        # Remove redundant input_layernorms (BitLinear already normalizes its input)
        elif isinstance(module, LlamaDecoderLayer):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
                    setattr(module, child_name, nn.Identity().to(device="cuda:0"))
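

# Minimal usage sketch (an illustrative addition, not part of the original utility).
# It assumes a CUDA device (convert_to_bitnet hard-codes "cuda:0") and a transformers
# version whose default Llama attention class is LlamaSdpaAttention; the tiny random
# config below is only so the sketch runs without downloading a checkpoint.
if __name__ == "__main__":
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaForCausalLM(config).to("cuda:0")
    convert_to_bitnet(model, copy_weights=True)
    # Attention/MLP projections are now BitLinear and input_layernorms are nn.Identity.
    print(model)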