Support llm-awq backend (#9856)

* Support for LLM-AWQ Backend

* fix

* Update README.md

* Add awqconfig

* modify init

* update

* support llm-awq

* fix style

* fix style

* update

* fix AwqBackendPackingMethod not found error

* fix style

* update README

* fix style

---------

Co-authored-by: Uxito-Ada <414416158@qq.com>
Co-authored-by: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com>
Co-authored-by: cyita <yitastudy@gmail.com>
ZehuaCao 2024-01-09 13:07:32 +08:00 committed by GitHub
parent fea6f16057
commit 146076bdb5
7 changed files with 78 additions and 32 deletions


@@ -4,6 +4,8 @@ This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel
 ## Verified Models
+### Auto-AWQ Backend
 - [Llama-2-7B-Chat-AWQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
 - [CodeLlama-7B-AWQ](https://huggingface.co/TheBloke/CodeLlama-7B-AWQ)
 - [Mistral-7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-AWQ)
@@ -15,6 +17,10 @@ This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel
 - [Yi-34B-AWQ](https://huggingface.co/TheBloke/Yi-34B-AWQ)
 - [Mixtral-8x7B-Instruct-v0.1-AWQ](https://huggingface.co/ybelkada/Mixtral-8x7B-Instruct-v0.1-AWQ)
+### llm-AWQ Backend
+- [vicuna-7b-1.5-awq](https://huggingface.co/ybelkada/vicuna-7b-1.5-awq)
 ## Requirements
 To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../../../README.md#system-support) for more information.
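For orientation, below is a minimal usage sketch for the checkpoints listed above. It is not part of this diff: the model id, prompt, and generation arguments are illustrative, and the `load_in_4bit=True` / `trust_remote_code=True` pattern is assumed to follow the existing BigDL-LLM AWQ example.

```python
# Minimal sketch, not part of this commit: loading one of the verified AWQ
# checkpoints with BigDL-LLM. Model id, prompt, and max_new_tokens are
# illustrative choices, not values taken from the example itself.
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "TheBloke/Llama-2-7B-Chat-AWQ"   # any model from the list above
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,       # convert AWQ weights to BigDL-LLM 4-bit
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```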


@@ -30,7 +30,6 @@ conda activate llm
 pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
 pip install transformers==4.34.0 # upgrade transformers
 ```
 ### 2. Run
 After setting up the Python environment, you could run the example by following steps.


@@ -226,6 +226,7 @@ def _replace_with_awq_layers(model, awq_config: AwqConfig):
             q_linear = q_linear_module.from_linear(module,
                                                    awq_config.bits,
                                                    awq_config.group_size,
+                                                   awq_config.backend,
                                                    True)
             q_linear.to(next(layer.parameters()).device)
             set_op_by_name(layer, name, q_linear)


@@ -83,7 +83,7 @@ class AwqConfig(QuantizationConfigMixin):
         self.bits = bits
         self.group_size = group_size
         self.zero_point = zero_point
-        self.version = version
+        self.version = version.lower()
         self.backend = backend
         self.modules_to_not_convert = modules_to_not_convert
@@ -93,9 +93,11 @@ class AwqConfig(QuantizationConfigMixin):
         r"""
         Safety checker that arguments are correct
         """
-        invalidInputError(self.backend == AwqBackendPackingMethod.AUTOAWQ,
+        invalidInputError(self.backend == AwqBackendPackingMethod.AUTOAWQ
+                          or self.backend == AwqBackendPackingMethod.LLMAWQ,
                           "Only supported quantization backends in "
-                          f"{AwqBackendPackingMethod.AUTOAWQ} - "
+                          f"{AwqBackendPackingMethod.AUTOAWQ} and "
+                          f"{AwqBackendPackingMethod.LLMAWQ} and "
                           f"not recognized backend {self.backend}")
         invalidInputError(self.version in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV],
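A standalone illustration of the relaxed check above. This is a sketch only: `check_backend` is a hypothetical helper, and a plain exception stands in for `invalidInputError`; `AwqBackendPackingMethod` comes from transformers, as the imports in this diff do.

```python
# Sketch of the post_init backend check after this change: both packing
# backends are accepted, anything else is rejected. check_backend is a
# hypothetical stand-in for the invalidInputError call in the diff.
from transformers.utils.quantization_config import AwqBackendPackingMethod

def check_backend(backend):
    supported = (AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ)
    if backend not in supported:
        raise ValueError(f"Only supported quantization backends are "
                         f"{AwqBackendPackingMethod.AUTOAWQ} and "
                         f"{AwqBackendPackingMethod.LLMAWQ}, "
                         f"not recognized backend {backend}")

check_backend(AwqBackendPackingMethod.LLMAWQ)   # accepted after this commit
check_backend(AwqBackendPackingMethod.AUTOAWQ)  # accepted as before
```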


@@ -44,6 +44,8 @@
 import torch
 import torch.nn as nn
 from bigdl.llm.utils.common import invalidOperationError, invalidInputError
+from transformers import AwqConfig
+from transformers.utils.quantization_config import AwqBackendPackingMethod
 def make_divisible(c, divisor):
@@ -67,7 +69,7 @@ def calculate_zeros_width(in_features, group_size=128, pack_num=8):
 class WQLinear_GEMM(nn.Module):
-    def __init__(self, bits, group_size, in_features, out_features, bias, dev):
+    def __init__(self, bits, group_size, in_features, out_features, bias, dev, backend):
         super().__init__()
         invalidOperationError(bits == 4, "Only 4-bit are supported for now.")
@@ -76,16 +78,30 @@ class WQLinear_GEMM(nn.Module):
         self.out_features = out_features
         self.bits = bits
         self.group_size = group_size if group_size != -1 else in_features
+        self.backend = backend
-        self.wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
-                                dtype=torch.int32) * self.bits).unsqueeze(0)
         # quick sanity check (make sure aligment)
         invalidInputError(self.in_features % self.group_size == 0,
                           f"Invalid in_features number {self.in_features}.")
         invalidInputError(out_features % (32 // self.bits) == 0,
                           f"Invalid out_features number {out_features}.")
+        if backend == AwqBackendPackingMethod.LLMAWQ:
+            self.wf = (torch.tensor([0, 1, 2, 3, 4, 5, 6, 7],
+                                    dtype=torch.int32) * self.bits).unsqueeze(0)
+            self.register_buffer('qweight',
+                                 torch.zeros((out_features,
+                                              in_features // (32 // self.bits)),
+                                             dtype=torch.int32, device=dev))
+            zeros_width = calculate_zeros_width(in_features, self.group_size)
+            self.register_buffer('qzeros',
+                                 torch.zeros((out_features, zeros_width),
+                                             dtype=torch.int32, device=dev))
+            self.register_buffer('scales',
+                                 torch.zeros((out_features, zeros_width * (32 // self.bits)),
+                                             dtype=torch.float16, device=dev))
+        elif backend == AwqBackendPackingMethod.AUTOAWQ:
+            self.wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                                    dtype=torch.int32) * self.bits).unsqueeze(0)
             self.register_buffer('qweight',
                                  torch.zeros((in_features,
                                               out_features // (32 // self.bits)),
@@ -104,9 +120,10 @@ class WQLinear_GEMM(nn.Module):
             self.bias = None
     @classmethod
-    def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
+    def from_linear(cls, linear, bits, group_size, backend,
+                    init_only=False, scales=None, zeros=None):
         awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
-                         linear.bias is not None, linear.weight.device)
+                         linear.bias is not None, linear.weight.device, backend)
         if init_only:  # just prepare for loading sd
             return awq_linear
@@ -139,7 +156,10 @@ class WQLinear_GEMM(nn.Module):
         for col in range(intweight.shape[1] // pack_num):
             if awq_linear.bits == 4:
+                if backend == AwqBackendPackingMethod.AUTOAWQ:
                     order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+                elif backend == AwqBackendPackingMethod.LLMAWQ:
+                    order_map = [0, 1, 2, 3, 4, 5, 6, 7]
             else:
                 invalidOperationError(False, "Only 4-bit are supported for now.")
             for i in range(pack_num):
@@ -153,7 +173,10 @@ class WQLinear_GEMM(nn.Module):
         for col in range(zeros.shape[1] // pack_num):
             if awq_linear.bits == 4:
+                if backend == AwqBackendPackingMethod.AUTOAWQ:
                     order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+                elif backend == AwqBackendPackingMethod.LLMAWQ:
+                    order_map = [0, 1, 2, 3, 4, 5, 6, 7]
             else:
                 invalidOperationError(False, "Only 4-bit are supported for now.")
             for i in range(pack_num):
@@ -211,7 +234,8 @@ class WQLinear_GEMV(nn.Module):
             self.bias = None
     @classmethod
-    def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
+    def from_linear(cls, linear, bits, group_size, backend,
+                    init_only=False, scales=None, zeros=None):
         awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
                          linear.bias is not None, linear.weight.device)
         if init_only:  # just prepare for loading sd
@@ -246,6 +270,9 @@ class WQLinear_GEMV(nn.Module):
         for col in range(intweight.shape[1] // pack_num):
             if awq_linear.bits == 4:
+                if backend == AwqBackendPackingMethod.AUTOAWQ:
+                    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+                elif backend == AwqBackendPackingMethod.LLMAWQ:
                     order_map = [0, 1, 2, 3, 4, 5, 6, 7]
             else:
                 invalidOperationError(False, "Only 4-bit are supported for now.")
@@ -263,6 +290,9 @@ class WQLinear_GEMV(nn.Module):
         for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
             if awq_linear.bits == 4:
+                if backend == AwqBackendPackingMethod.AUTOAWQ:
+                    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+                elif backend == AwqBackendPackingMethod.LLMAWQ:
                     order_map = [0, 1, 2, 3, 4, 5, 6, 7]
             else:
                 invalidOperationError(False, "Only 4-bit are supported for now.")
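The two `order_map` values above are the heart of this file's change: AutoAWQ interleaves eight 4-bit values inside each int32, while llm-awq stores them sequentially, and the `wf` shift tables in `__init__` mirror the two layouts. A small self-contained sketch (toy values, not taken from the diff) shows that each shift table recovers the values packed with the corresponding order:

```python
# Illustration only: how the two order_map values place eight 4-bit values
# inside one int32, and how the matching wf shift table unpacks them again.
import torch

def pack_int4(values, order_map, bits=4):
    """Pack eight small ints (0..15) into a single int32 following order_map."""
    packed = 0
    for i, src in enumerate(order_map):
        packed |= int(values[src]) << (i * bits)
    return packed

vals = list(range(8))                                    # toy 4-bit weights 0..7
autoawq_word = pack_int4(vals, [0, 2, 4, 6, 1, 3, 5, 7])  # interleaved nibbles
llm_awq_word = pack_int4(vals, [0, 1, 2, 3, 4, 5, 6, 7])  # sequential nibbles

def unpack_int4(packed, wf, bits=4):
    """Undo the packing with a shift table, as the conversion code does."""
    return [(packed >> int(shift)) & ((1 << bits) - 1) for shift in wf]

# The wf tables registered in __init__ recover the original order per backend.
print(unpack_int4(autoawq_word, torch.tensor([0, 4, 1, 5, 2, 6, 3, 7]) * 4))  # [0, 1, ..., 7]
print(unpack_int4(llm_awq_word, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) * 4))  # [0, 1, ..., 7]
```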


@@ -74,9 +74,9 @@ def is_deepspeed_available():
 if is_auto_gptq_available():
     from auto_gptq.utils.peft_utils import QuantLinearCuda, QuantLinearCudaOld
 if is_auto_awq_available():
     from bigdl.llm.transformers.awq.linear import WQLinear_GEMM
+    from transformers.utils.quantization_config import AwqBackendPackingMethod
 def is_linear_module(module):
@@ -118,7 +118,7 @@ def is_linear_module(module):
     return result, (in_features, out_features, mp_group)
-def convert_gptq(module, awq=False):
+def convert_gptq(module, awq=False, llm_awq=False):
     from bigdl.llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
@@ -139,6 +139,8 @@ def convert_gptq(module, awq=False):
             module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
         weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
         weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+        if llm_awq:
+            weight = weight.t()
     else:
         weight = torch.bitwise_right_shift(
             torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1),
@@ -155,6 +157,12 @@ def convert_gptq(module, awq=False):
         weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()
     # convert zeros to ggml format
+    if llm_awq:
+        real_scale_num = module.in_features // module.group_size
+        zeros = zeros[:, : real_scale_num]
+        scales = scales[:, : real_scale_num]
+        zeros = zeros.t()
+        scales = scales.t()
     zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
         .unsqueeze(2)\
         .expand(-1, -1, module.group_size//Q4_1, -1)\
@@ -200,6 +208,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             new_linear = None
             is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
             is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
+            is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
            if is_gptq or is_awq:
                 has_bias = module.bias is not None and module.bias.abs().sum() != 0
                 new_linear = LowBitLinear(
@@ -213,7 +222,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 invalidInputError(device.type != "meta",
                                   "converting from meta device is not supported")
                 # Copy the weights
-                paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
+                paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
+                                                           llm_awq=is_llm_awq),
                                          requires_grad=False,
                                          quantized=True,
                                          _shape=(out_features, in_features),
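A shape-only sketch of why the `llm_awq` branch in `convert_gptq` transposes: the LLMAWQ layout registered in `WQLinear_GEMM.__init__` above stores `qweight` as `(out_features, in_features // 8)` while AutoAWQ uses the transpose, and its scales/zeros are padded to `calculate_zeros_width`, so only the first `in_features // group_size` columns carry real per-group values. The concrete dimensions below are illustrative, and the orientation produced by the unpack step is inferred from the reshape shown in the hunk, not copied from the source.

```python
# Illustrative shapes only; nothing here is taken verbatim from the diff.
import torch

in_features, out_features, group_size, pack_num = 4096, 11008, 128, 8

autoawq_qweight = torch.zeros(in_features, out_features // pack_num, dtype=torch.int32)
llm_awq_qweight = torch.zeros(out_features, in_features // pack_num, dtype=torch.int32)

# Unpacking 8 nibbles per int32 along the packed dimension appears to yield
# (out_features, in_features) for the llm-awq layout; the added weight.t()
# brings it to (in_features, out_features), matching the AutoAWQ path.
unpacked_llm_awq = torch.zeros(out_features, in_features, dtype=torch.int8)
assert unpacked_llm_awq.t().shape == (in_features, out_features)

# Scales/zeros in the llm-awq layout are padded, so the conversion keeps only
# the first in_features // group_size columns before transposing.
real_scale_num = in_features // group_size   # 32 per-group columns here
```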


@@ -212,8 +212,6 @@ class _BaseAutoModelClass:
                               "Only 4-bit awq is supported in bigdl-llm.")
             invalidInputError(awq_config.version == "gemm",
                               "Only gemm version is supported in bigdl-llm.")
-            invalidInputError(awq_config.backend == "autoawq",
-                              "Only autoawq backend is supported in bigdl-llm.")
             invalidInputError(awq_config.zero_point,
                               "Only awq zero_point = True is supported in bigdl-llm.")
             if load_in_low_bit is not None:
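With the autoawq-only assertion removed, an llm-awq checkpoint such as the vicuna model added to the README should pass through the same loading path. A hedged sketch; the keyword arguments are assumed to follow the existing AWQ example rather than anything introduced by this diff:

```python
# Sketch under assumptions: the llm-awq checkpoint listed in the README loads
# the same way as the Auto-AWQ ones; load_in_4bit / trust_remote_code mirror
# the existing example and are not added by this commit.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("ybelkada/vicuna-7b-1.5-awq",
                                             load_in_4bit=True,
                                             trust_remote_code=True)
```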