add repack_awq_to_optimum_format function (#1998)
Signed-off-by: changwangss <[email protected]>
changwangss authored Sep 20, 2024
1 parent 4ee6861 commit ee600ba
Showing 5 changed files with 313 additions and 25 deletions.
219 changes: 219 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/utility.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight-Only utility."""
import numpy as np
import torch

from neural_compressor.torch.utils import accelerator, device_synchronize, logger
@@ -1228,3 +1229,221 @@ def convert_dtype_str2torch(str_dtype):
return torch.bfloat16
else:
assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype)


# ref: reverse reorder, adapted from AutoGPTQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L491
def awq_reverse_reorder_int_tensor(int_tensor, bits: int):
"""Awq tensor convert tool.
Reverse_reorder_int_tensor
"""
assert bits == 4

int_tensor = int_tensor.T.contiguous()
compress_ratio = 32 // bits
assert int_tensor.shape[-1] % compress_ratio == 0

order_map = [0, 2, 4, 6, 1, 3, 5, 7]
order_tensor = torch.tensor(order_map, dtype=torch.int32, device=int_tensor.device).reshape(1, -1)
order_tensor = order_tensor.repeat(int_tensor.shape[1] // compress_ratio, 1)
order_tensor = order_tensor + torch.arange(
0,
int_tensor.shape[1],
compress_ratio,
dtype=torch.int32,
device=int_tensor.device,
).reshape(-1, 1)
order_tensor = order_tensor.reshape(-1)

reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor]
reverse_order_tensor = reverse_order_tensor[order_tensor]
int_tensor = int_tensor[:, reverse_order_tensor]
return int_tensor
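
The permutation built above is compact but non-obvious: AWQ interleaves the eight 4-bit values of each packed 32-bit word in the order [0, 2, 4, 6, 1, 3, 5, 7], and indexing that order with itself yields its inverse (the permutation has order 3), which is exactly the double-indexing trick used in the function. A standalone sketch with an assumed toy width of 16 columns, not part of this patch:

import torch

bits = 4
compress_ratio = 32 // bits              # eight 4-bit values per int32
n_cols = 16                              # toy width; must be a multiple of 8

order_map = [0, 2, 4, 6, 1, 3, 5, 7]     # AWQ stores even nibbles before odd ones
order = torch.tensor(order_map, dtype=torch.int64).reshape(1, -1)
order = order.repeat(n_cols // compress_ratio, 1)
order = (order + torch.arange(0, n_cols, compress_ratio).reshape(-1, 1)).reshape(-1)

reverse = order[order]                   # same double-indexing as the function above
assert torch.equal(reverse[order], torch.arange(n_cols))   # it is the inverse permutation
print(order.tolist())    # [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
print(reverse.tolist())  # [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]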


# ref: weight unpack, adapted from AutoGPTQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
def unpack_awq(
awq_qweight: torch.Tensor,
awq_qzeros: torch.Tensor,
awq_scales: torch.Tensor,
bits: int,
group_size: int,
):
"""Unpack awq format to actual values.
Args:
awq_qweight (`torch.LongTensor`):
Expected shape: (in_features, out_features // (32 // bits))
awq_qzeros (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features // (32 // bits))
awq_scales (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)
Returns:
fp16_weight (`torch.Tensor`, same dtype as awq_scales):
With shape (out_features, in_features).
zeros (`torch.LongTensor`):
With shape (in_features // group_size, out_features).
"""
assert bits == 4

qzeros = awq_qzeros
qweight = awq_qweight
qweight = qweight.T.contiguous()

infeatures = awq_qweight.shape[0]

wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device).unsqueeze(0)
zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to(
torch.int16 if bits == 8 else torch.int8
)

# zeros = zeros + 1

torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)

zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])

weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1), wf.unsqueeze(-1)).to(
torch.int16 if bits == 8 else torch.int8
)
torch.bitwise_and(weight, (2**bits) - 1, out=weight)
weight = weight.reshape(-1, group_size, weight.shape[2])

weight = weight.view(-1, weight.shape[-1])
zeros = zeros.view(-1, zeros.shape[-1])

zeros = zeros.T.contiguous()
zeros = awq_reverse_reorder_int_tensor(zeros, bits)
weight = awq_reverse_reorder_int_tensor(weight, bits)

# Dequantize weights.
scales = awq_scales
zeros = zeros.contiguous()
scale_zeros = zeros * scales

g_idx = torch.tensor([i // group_size for i in range(infeatures)], dtype=torch.int32)
scale_mat = scales[g_idx]
scale_zeros_mat = scale_zeros[g_idx].half()

qdq_weight_T = weight * scale_mat - scale_zeros_mat.half()

fp16_weight = qdq_weight_T.T

return fp16_weight, zeros
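
As a quick check of the shape contract, the sketch below (standalone, with assumed toy sizes and random packed values, not taken from this commit) runs unpack_awq and prints what comes back; note that the dequantized weight is returned in the usual nn.Linear (out_features, in_features) layout.

import torch

from neural_compressor.torch.algorithms.weight_only.utility import unpack_awq

in_features, out_features, bits, group_size = 128, 64, 4, 32
pack_factor = 32 // bits  # eight 4-bit values per int32

awq_qweight = torch.randint(0, 2**31 - 1, (in_features, out_features // pack_factor), dtype=torch.int32)
awq_qzeros = torch.randint(0, 2**31 - 1, (in_features // group_size, out_features // pack_factor), dtype=torch.int32)
awq_scales = (torch.rand(in_features // group_size, out_features) + 0.5).half()

fp16_weight, zeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
print(fp16_weight.shape)  # torch.Size([64, 128]) -> (out_features, in_features)
print(zeros.shape)        # torch.Size([4, 64])   -> (in_features // group_size, out_features)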


# ref: weight pack, adapted from AutoGPTQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
def pack_from_tensors(
unpacked_qweight: torch.Tensor,
unpacked_qzeros: torch.Tensor,
awq_scales: torch.Tensor,
bits: int,
group_size: int,
):
"""Pack the tensor to optimum format.
Args:
unpacked_qweight (`torch.Tensor`):
Expected shape: (out_features, in_features), as returned by unpack_awq
unpacked_qzeros (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)
awq_scales (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)
Returns:
qweight (`torch.LongTensor`):
With shape (in_features // (32 // bits), out_features)
qzeros (`torch.LongTensor`):
With shape (in_features // group_size, out_features // (32 // bits))
"""
assert bits == 4
W = unpacked_qweight.clone().cpu()

# TODO: This should be checked somehow.
# if isinstance(linear, nn.Conv2d):
# W = W.flatten(1)
# if isinstance(linear, transformers.pytorch_utils.Conv1D):
# W = W.t()

awq_scales = awq_scales.t().contiguous()
unpacked_qzeros = unpacked_qzeros.contiguous()
unpacked_qzeros = unpacked_qzeros.cpu()

awq_scales = awq_scales.cpu()
scale_zeros = unpacked_qzeros.t() * awq_scales
scales = awq_scales.clone()

infeatures = unpacked_qweight.shape[1]

intweight = []
for idx in range(infeatures):
g_idx = idx // group_size

intweight.append(torch.round((W[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)

i = 0
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=np.uint32)
while row < qweight.shape[0]:
for j in range(i, i + (32 // bits)):
qweight[row] |= intweight[j] << (bits * (j - i))
i += 32 // bits
row += 1

qweight = qweight.astype(np.int32)
qweight = torch.from_numpy(qweight)

unpacked_qzeros = unpacked_qzeros - 1
torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros)

unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(unpacked_qzeros.shape[0], unpacked_qzeros.shape[1] // 32 * bits),
dtype=np.uint32,
)
i = 0
col = 0
while col < qzeros.shape[1]:
for j in range(i, i + (32 // bits)):
qzeros[:, col] |= unpacked_qzeros[:, j] << (bits * (j - i))
i += 32 // bits
col += 1

qzeros = qzeros.astype(np.int32)
qzeros = torch.from_numpy(qzeros)

return qweight, qzeros
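
The two while-loops above are plain nibble packing: each 32-bit word accumulates eight 4-bit values, lowest nibble first. A minimal standalone illustration in plain Python (not part of the patch), together with the inverse shift-and-mask that unpack_awq's bitwise ops perform:

bits = 4
vals = [1, 2, 3, 4, 5, 6, 7, 8]           # eight 4-bit values -> one 32-bit word

packed = 0
for j, v in enumerate(vals):
    packed |= v << (bits * j)             # low nibble first, as in the loops above
print(hex(packed))                        # 0x87654321

unpacked = [(packed >> (bits * j)) & ((1 << bits) - 1) for j in range(len(vals))]
assert unpacked == vals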


def repack_awq_to_optimum_format(
awq_qweight: torch.Tensor,
awq_qzeros: torch.Tensor,
awq_scales: torch.Tensor,
bits: int,
group_size: int,
):
"""The function to repack_awq_to_optimum_format.
Args:
awq_qweight (`torch.LongTensor`):
Expected shape: (in_features, out_features // (32 // bits))
awq_qzeros (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features // (32 // bits))
awq_scales (`torch.LongTensor`):
Expected shape: (in_features // group_size, out_features)
Returns:
qweight (`torch.LongTensor`):
With shape (in_features // (32 // bits), out_features)
qzeros (`torch.LongTensor`):
With shape (in_features // group_size, out_features // (32 // bits))
scales (`torch.Tensor`):
With shape (in_features // group_size, out_features).
"""
unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales, bits, group_size)
return qweight, qzeros, awq_scales
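
End to end, the repack converts the AWQ on-disk layout (qweight packed along out_features) into the Optimum/GPTQ layout (qweight packed along in_features) that INCWeightOnlyLinear consumes. A hedged sketch with assumed toy sizes and random contents, not taken from the commit:

import torch

from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format

in_f, out_f, bits, group_size = 256, 128, 4, 64

awq_qweight = torch.randint(0, 2**31 - 1, (in_f, out_f // 8), dtype=torch.int32)
awq_qzeros = torch.randint(0, 2**31 - 1, (in_f // group_size, out_f // 8), dtype=torch.int32)
awq_scales = (torch.rand(in_f // group_size, out_f) + 0.5).half()

qweight, qzeros, scales = repack_awq_to_optimum_format(
    awq_qweight, awq_qzeros, awq_scales, bits, group_size
)
print(qweight.shape)  # torch.Size([32, 128]) -> (in_f // 8, out_f)
print(qzeros.shape)   # torch.Size([4, 16])   -> (in_f // group_size, out_f // 8)
print(scales.shape)   # torch.Size([4, 128])  -> passed through unchanged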
66 changes: 41 additions & 25 deletions neural_compressor/transformers/models/modeling_auto.py
@@ -47,7 +47,13 @@
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
from neural_compressor.torch.utils import set_module

from ..quantization.utils import convert_dtype_torch2str, convert_to_quantized_model, replace_linear, save_low_bit
from ..quantization.utils import (
convert_dtype_torch2str,
convert_to_quantized_model,
repack_awq_and_load_state_dict,
replace_linear,
save_low_bit,
)
from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig


@@ -179,6 +185,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
) and model.config.model_type == "chatglm":
model = model.float()
model = convert_to_quantized_model(model, quantization_config, device=device_map)
if isinstance(quantization_config, AwqConfig):
quantization_config.backend = "inc"
quantization_config.remove_redundant_parameters()
model.config.quantization_config = quantization_config
else:
@@ -295,6 +303,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
quantization_config = GPTQConfig.from_dict(quantization_config)
elif quantization_config["quant_method"] == "autoround":
quantization_config = AutoRoundConfig.from_dict(quantization_config)

assert quantization_config is not None, "Detect this model is not a low-bit model."

if commit_hash is None:
@@ -613,41 +622,48 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):

with ContextManagers(init_contexts):
model = model_class(config, *model_args, **kwargs)

if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
if quantization_config.modules_to_not_convert is None:
quantization_config.modules_to_not_convert = ["lm_head", "transformer.output_layer", "embed_out"]
else:
quantization_config.modules_to_not_convert += ["lm_head", "transformer.output_layer", "embed_out"]
model = build_woq_model(model, quantization_config)

if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
loaded_state_dict_keys = list(state_dict.keys())

# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)
if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
model = repack_awq_and_load_state_dict(
model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
)
else:
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = model_class._load_pretrained_model(
model,
None,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=True,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
keep_in_fp32_modules=[],
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
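
From the user's perspective, this branch is what allows a third-party AWQ checkpoint to be loaded directly into INC's weight-only modules. A hedged sketch of the intended call, assuming the package exposes AutoModelForCausalLM at this path and that from_pretrained routes already-quantized checkpoints through load_low_bit (the model id is purely illustrative):

from neural_compressor.transformers import AutoModelForCausalLM

# The checkpoint's quantization_config carries quant_method "awq" and no
# backend="inc" tag, so the loader builds WOQ modules via build_woq_model and
# then repacks the weights with repack_awq_and_load_state_dict.
model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-AWQ")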
38 changes: 38 additions & 0 deletions neural_compressor/transformers/quantization/utils.py
@@ -23,6 +23,7 @@

from neural_compressor.common.utils import LazyImport, logger
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format
from neural_compressor.torch.quantization import (
AutoRoundConfig,
AWQConfig,
@@ -654,3 +655,40 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: boo
token=kwargs.get("token"),
)
self.quantization_config.save_pretrained(save_directory, **kwargs)


def repack_awq_and_load_state_dict(
model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
):
from transformers.modeling_utils import load_state_dict

bits = quantization_config.bits
group_size = quantization_config.group_size

state_dict = {}
if isinstance(resolved_archive_file, str):
resolved_archive_file = [resolved_archive_file]
assert isinstance(resolved_archive_file, list), "Please check if the loading weight is sharded."
for shard_file in resolved_archive_file:
assert shard_file.endswith("safetensors"), "AWQ repacking expects checkpoints saved in safetensors format."
state_dict.update(load_state_dict(shard_file))
assert len(state_dict.keys()) > 0, "Please check the state_dict loading."
for name, module in model.named_modules():
if isinstance(module, INCWeightOnlyLinear):
assert name + ".qweight" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.qweight'}"
assert name + ".qzeros" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qzeros'}"
assert name + ".scales" in loaded_state_dict_keys, f"Please check the state_dict key { name + '.scales'}"
if name + ".scales" in loaded_state_dict_keys:
awq_qweight = state_dict[name + ".qweight"]
awq_qzeros = state_dict[name + ".qzeros"]
awq_scales = state_dict[name + ".scales"]
qweight, qzeros, awq_scales = repack_awq_to_optimum_format(
awq_qweight, awq_qzeros, awq_scales, bits, group_size
)
state_dict[name + ".qweight"] = qweight
state_dict[name + ".qzeros"] = qzeros
state_dict[name + ".scales"] = awq_scales

model.load_state_dict(state_dict, strict=False, assign=True)

return model
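
The closing model.load_state_dict(state_dict, strict=False, assign=True) call adopts the repacked tensors instead of copying them into the placeholder buffers of the freshly built WOQ model. A minimal, standalone illustration of what assign=True changes (assumes a torch version that supports the flag, 2.1 or newer):

import torch

lin = torch.nn.Linear(4, 4, bias=False)          # fp32 placeholder weight
packed = torch.zeros(4, 4, dtype=torch.float16)  # stands in for a repacked tensor

lin.load_state_dict({"weight": packed}, assign=True)
print(lin.weight.dtype)  # torch.float16 -> the incoming tensor's properties are kept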
2 changes: 2 additions & 0 deletions neural_compressor/transformers/utils/quantization_config.py
@@ -409,6 +409,7 @@ def __init__(
zero_point: bool = True,
absorb_layer_dict: dict = {},
quant_lm_head: bool = False,
backend: str = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
@@ -427,6 +428,7 @@ def __init__(
self.seq_len = seq_len
self.absorb_layer_dict = absorb_layer_dict
self.quant_lm_head = quant_lm_head
self.backend = backend
self.modules_to_not_convert = kwargs.get(
"modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"]
)
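
For completeness, a short sketch of how the new flag is meant to be set; the constructor arguments are assumed from the attributes this diff reads elsewhere (bits, group_size), and the import path mirrors the package layout shown above:

from neural_compressor.transformers.utils import AwqConfig

# Models quantized inside INC are tagged backend="inc" (see from_pretrained above),
# so load_low_bit skips AWQ repacking for them.
inc_cfg = AwqConfig(bits=4, group_size=128, backend="inc")

# Leaving backend unset marks a third-party AWQ checkpoint, which is routed
# through repack_awq_and_load_state_dict at load time.
awq_cfg = AwqConfig(bits=4, group_size=128)
print(inc_cfg.backend, awq_cfg.backend)  # inc None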
