Support llm-awq backend (#9856)
* Support for LLM-AWQ Backend
* fix
* Update README.md
* Add awqconfig
* modify init
* update
* support llm-awq
* fix style
* fix style
* update
* fix AwqBackendPackingMethod not found error
* fix style
* update README
* fix style

---------

Co-authored-by: Uxito-Ada <414416158@qq.com>
Co-authored-by: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com>
Co-authored-by: cyita <yitastudy@gmail.com>
parent: fea6f16057
commit: 146076bdb5

7 changed files with 78 additions and 32 deletions
@@ -4,6 +4,8 @@ This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel

## Verified Models

### Auto-AWQ Backend

- [Llama-2-7B-Chat-AWQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
- [CodeLlama-7B-AWQ](https://huggingface.co/TheBloke/CodeLlama-7B-AWQ)
- [Mistral-7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-AWQ)
@@ -15,6 +17,10 @@ This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel

- [Yi-34B-AWQ](https://huggingface.co/TheBloke/Yi-34B-AWQ)
- [Mixtral-8x7B-Instruct-v0.1-AWQ](https://huggingface.co/ybelkada/Mixtral-8x7B-Instruct-v0.1-AWQ)

### llm-AWQ Backend

- [vicuna-7b-1.5-awq](https://huggingface.co/ybelkada/vicuna-7b-1.5-awq)

## Requirements

To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../../../README.md#system-support) for more information.
@@ -30,7 +30,6 @@ conda activate llm

pip install --pre --upgrade bigdl-llm[all]  # install the latest bigdl-llm nightly build with 'all' option
pip install transformers==4.34.0  # upgrade transformers
```

### 2. Run

After setting up the Python environment, you can run the example by following the steps below.
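For orientation, here is a minimal sketch of what running one of the verified AWQ checkpoints with BigDL-LLM's transformers-style API typically looks like. The model path, prompt, and generation settings are placeholders, and the example's actual script may differ in details.

```python
# Minimal sketch (not the example's exact script): load a verified AWQ
# checkpoint through BigDL-LLM's transformers-style API and generate.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "TheBloke/Llama-2-7B-Chat-AWQ"   # placeholder: any verified AWQ model

# load_in_4bit=True converts the AWQ weights into BigDL-LLM's 4-bit format
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```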
@@ -226,6 +226,7 @@ def _replace_with_awq_layers(model, awq_config: AwqConfig):
        q_linear = q_linear_module.from_linear(module,
                                               awq_config.bits,
                                               awq_config.group_size,
                                               awq_config.backend,
                                               True)
        q_linear.to(next(layer.parameters()).device)
        set_op_by_name(layer, name, q_linear)
@@ -83,7 +83,7 @@ class AwqConfig(QuantizationConfigMixin):
        self.bits = bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.version = version.lower()
        self.backend = backend
        self.modules_to_not_convert = modules_to_not_convert
@@ -93,9 +93,11 @@ class AwqConfig(QuantizationConfigMixin):
        r"""
        Safety checker that arguments are correct
        """
        invalidInputError(self.backend == AwqBackendPackingMethod.AUTOAWQ,
        invalidInputError(self.backend == AwqBackendPackingMethod.AUTOAWQ
                          or self.backend == AwqBackendPackingMethod.LLMAWQ,
                          "Only supported quantization backends in "
                          f"{AwqBackendPackingMethod.AUTOAWQ} - "
                          f"{AwqBackendPackingMethod.AUTOAWQ} and "
                          f"{AwqBackendPackingMethod.LLMAWQ} and "
                          f"not recognized backend {self.backend}")

        invalidInputError(self.version in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV],
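After this change the post-init check accepts either packing backend. A standalone sketch of the relaxed condition, using the enum the diff imports; `check_backend` is an illustrative helper, not project code.

```python
# Illustrative helper (not project code): the relaxed backend check now
# admits both packing methods defined by transformers.
from transformers.utils.quantization_config import AwqBackendPackingMethod

def check_backend(backend):
    if backend not in (AwqBackendPackingMethod.AUTOAWQ,
                       AwqBackendPackingMethod.LLMAWQ):
        raise ValueError(f"not recognized backend {backend}")

check_backend(AwqBackendPackingMethod.AUTOAWQ)  # accepted, as before
check_backend(AwqBackendPackingMethod.LLMAWQ)   # accepted after this patch
# check_backend("gptq")                         # would raise ValueError
```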
@@ -44,6 +44,8 @@
import torch
import torch.nn as nn
from bigdl.llm.utils.common import invalidOperationError, invalidInputError
from transformers import AwqConfig
from transformers.utils.quantization_config import AwqBackendPackingMethod


def make_divisible(c, divisor):
@@ -67,7 +69,7 @@ def calculate_zeros_width(in_features, group_size=128, pack_num=8):


class WQLinear_GEMM(nn.Module):
    def __init__(self, bits, group_size, in_features, out_features, bias, dev):
    def __init__(self, bits, group_size, in_features, out_features, bias, dev, backend):
        super().__init__()

        invalidOperationError(bits == 4, "Only 4-bit are supported for now.")
@@ -76,27 +78,41 @@ class WQLinear_GEMM(nn.Module):
        self.out_features = out_features
        self.bits = bits
        self.group_size = group_size if group_size != -1 else in_features

        self.wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
                                dtype=torch.int32) * self.bits).unsqueeze(0)
        self.backend = backend

        # quick sanity check (make sure alignment)
        invalidInputError(self.in_features % self.group_size == 0,
                          f"Invalid in_features number {self.in_features}.")
        invalidInputError(out_features % (32 // self.bits) == 0,
                          f"Invalid out_features number {out_features}.")

        self.register_buffer('qweight',
                             torch.zeros((in_features,
                                          out_features // (32 // self.bits)),
                                         dtype=torch.int32, device=dev))
        self.register_buffer('qzeros',
                             torch.zeros((in_features // self.group_size,
                                          out_features // (32 // self.bits)),
                                         dtype=torch.int32, device=dev))
        self.register_buffer('scales',
                             torch.zeros((in_features // self.group_size, out_features),
                                         dtype=torch.float16, device=dev))
        if backend == AwqBackendPackingMethod.LLMAWQ:
            self.wf = (torch.tensor([0, 1, 2, 3, 4, 5, 6, 7],
                                    dtype=torch.int32) * self.bits).unsqueeze(0)
            self.register_buffer('qweight',
                                 torch.zeros((out_features,
                                              in_features // (32 // self.bits)),
                                             dtype=torch.int32, device=dev))
            zeros_width = calculate_zeros_width(in_features, self.group_size)
            self.register_buffer('qzeros',
                                 torch.zeros((out_features, zeros_width),
                                             dtype=torch.int32, device=dev))
            self.register_buffer('scales',
                                 torch.zeros((out_features, zeros_width * (32 // self.bits)),
                                             dtype=torch.float16, device=dev))
        elif backend == AwqBackendPackingMethod.AUTOAWQ:
            self.wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
                                    dtype=torch.int32) * self.bits).unsqueeze(0)
            self.register_buffer('qweight',
                                 torch.zeros((in_features,
                                              out_features // (32 // self.bits)),
                                             dtype=torch.int32, device=dev))
            self.register_buffer('qzeros',
                                 torch.zeros((in_features // self.group_size,
                                              out_features // (32 // self.bits)),
                                             dtype=torch.int32, device=dev))
            self.register_buffer('scales',
                                 torch.zeros((in_features // self.group_size, out_features),
                                             dtype=torch.float16, device=dev))
        if bias:
            self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16,
                                                     device=dev))
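The two branches register differently shaped buffers: the llm-awq layout packs along in_features within each output row and pads its zeros/scales width, while the autoawq layout packs along out_features per input group. Below is a shape-only sketch of what the `register_buffer` calls above allocate for a 4-bit layer; `awq_buffer_shapes` is illustrative, not project code, and `zeros_width` is taken as given from `calculate_zeros_width`.

```python
# Shape-only sketch (illustrative, not project code) of the buffers the two
# branches above allocate for a 4-bit layer.
def awq_buffer_shapes(in_features, out_features, group_size, zeros_width, backend):
    pack = 32 // 4  # eight 4-bit values per int32
    if backend == "llm-awq":
        return {"qweight": (out_features, in_features // pack),
                "qzeros": (out_features, zeros_width),
                "scales": (out_features, zeros_width * pack)}
    # "autoawq"
    return {"qweight": (in_features, out_features // pack),
            "qzeros": (in_features // group_size, out_features // pack),
            "scales": (in_features // group_size, out_features)}

# e.g. a 4096x4096 projection with group_size=128 (zeros_width as computed
# by calculate_zeros_width, here assumed to be 32):
print(awq_buffer_shapes(4096, 4096, 128, 32, "llm-awq"))
print(awq_buffer_shapes(4096, 4096, 128, 32, "autoawq"))
```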
@@ -104,9 +120,10 @@ class WQLinear_GEMM(nn.Module):
            self.bias = None

    @classmethod
    def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
    def from_linear(cls, linear, bits, group_size, backend,
                    init_only=False, scales=None, zeros=None):
        awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
                         linear.bias is not None, linear.weight.device)
                         linear.bias is not None, linear.weight.device, backend)
        if init_only:  # just prepare for loading sd
            return awq_linear
@@ -139,7 +156,10 @@ class WQLinear_GEMM(nn.Module):

        for col in range(intweight.shape[1] // pack_num):
            if awq_linear.bits == 4:
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
                if backend == AwqBackendPackingMethod.AUTOAWQ:
                    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
                elif backend == AwqBackendPackingMethod.LLMAWQ:
                    order_map = [0, 1, 2, 3, 4, 5, 6, 7]
            else:
                invalidOperationError(False, "Only 4-bit are supported for now.")
            for i in range(pack_num):
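The `order_map` decides which source column lands in which 4-bit slot of the packed int32: AutoAWQ interleaves the nibbles, llm-awq keeps them sequential. A toy illustration in plain Python (not the project's packing loop):

```python
# Toy illustration (not project code): pack eight 4-bit values into one
# int32 according to an order_map.
def pack_int4(vals, order_map, bits=4):
    packed = 0
    for i, src in enumerate(order_map):
        packed |= (vals[src] & 0xF) << (i * bits)
    return packed

vals = [0, 1, 2, 3, 4, 5, 6, 7]
print(hex(pack_int4(vals, [0, 2, 4, 6, 1, 3, 5, 7])))  # 0x75316420  (AutoAWQ)
print(hex(pack_int4(vals, [0, 1, 2, 3, 4, 5, 6, 7])))  # 0x76543210  (llm-awq)
```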
@@ -153,7 +173,10 @@ class WQLinear_GEMM(nn.Module):

        for col in range(zeros.shape[1] // pack_num):
            if awq_linear.bits == 4:
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
                if backend == AwqBackendPackingMethod.AUTOAWQ:
                    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
                elif backend == AwqBackendPackingMethod.LLMAWQ:
                    order_map = [0, 1, 2, 3, 4, 5, 6, 7]
            else:
                invalidOperationError(False, "Only 4-bit are supported for now.")
            for i in range(pack_num):
@@ -211,7 +234,8 @@ class WQLinear_GEMV(nn.Module):
            self.bias = None

    @classmethod
    def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
    def from_linear(cls, linear, bits, group_size, backend,
                    init_only=False, scales=None, zeros=None):
        awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
                         linear.bias is not None, linear.weight.device)
        if init_only:  # just prepare for loading sd
@@ -246,7 +270,10 @@ class WQLinear_GEMV(nn.Module):

        for col in range(intweight.shape[1] // pack_num):
            if awq_linear.bits == 4:
                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
                if backend == AwqBackendPackingMethod.AUTOAWQ:
                    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
                elif backend == AwqBackendPackingMethod.LLMAWQ:
                    order_map = [0, 1, 2, 3, 4, 5, 6, 7]
            else:
                invalidOperationError(False, "Only 4-bit are supported for now.")
            for i in range(pack_num):
@@ -263,7 +290,10 @@ class WQLinear_GEMV(nn.Module):

        for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
            if awq_linear.bits == 4:
                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
                if backend == AwqBackendPackingMethod.AUTOAWQ:
                    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
                elif backend == AwqBackendPackingMethod.LLMAWQ:
                    order_map = [0, 1, 2, 3, 4, 5, 6, 7]
            else:
                invalidOperationError(False, "Only 4-bit are supported for now.")
            for i in range(pack_num):
@@ -74,9 +74,9 @@ def is_deepspeed_available():

if is_auto_gptq_available():
    from auto_gptq.utils.peft_utils import QuantLinearCuda, QuantLinearCudaOld

if is_auto_awq_available():
    from bigdl.llm.transformers.awq.linear import WQLinear_GEMM
    from transformers.utils.quantization_config import AwqBackendPackingMethod


def is_linear_module(module):
@@ -118,7 +118,7 @@ def is_linear_module(module):
    return result, (in_features, out_features, mp_group)


def convert_gptq(module, awq=False):
def convert_gptq(module, awq=False, llm_awq=False):
    from bigdl.llm.transformers.low_bit_linear import get_block_size
    Q4_1 = get_block_size("asym_int4")
@@ -139,6 +139,8 @@ def convert_gptq(module, awq=False):
            module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
        weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
        weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
        if llm_awq:
            weight = weight.t()
    else:
        weight = torch.bitwise_right_shift(
            torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1),
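The shift-and-mask above is what turns one packed int32 column back into its 4-bit values, with `wf` supplying the per-slot bit offsets. A small runnable sketch of the same trick, using the sequential llm-awq ordering for the offsets (illustrative only):

```python
# Runnable sketch (illustrative): unpack one int32 into eight 4-bit values
# by shifting with per-slot offsets and masking, as in the lines above.
import torch

bits = 4
wf = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32) * bits  # llm-awq order
packed = torch.tensor([0x76543210], dtype=torch.int32)

vals = torch.bitwise_right_shift(packed.unsqueeze(-1), wf)
vals = torch.bitwise_and(vals, (1 << bits) - 1)
print(vals)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
```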
@@ -155,6 +157,12 @@ def convert_gptq(module, awq=False):
        weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()

    # convert zeros to ggml format
    if llm_awq:
        real_scale_num = module.in_features // module.group_size
        zeros = zeros[:, : real_scale_num]
        scales = scales[:, : real_scale_num]
        zeros = zeros.t()
        scales = scales.t()
    zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
        .unsqueeze(2)\
        .expand(-1, -1, module.group_size//Q4_1, -1)\
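In the llm-awq branch above, zeros and scales arrive one row per output channel with their width padded by `calculate_zeros_width`, so only the first in_features // group_size columns are real; the slice keeps those and the transpose restores the (groups, out_features) orientation that the rest of the ggml conversion expects. A shape-only sketch with made-up sizes:

```python
# Shape-only sketch (illustrative sizes): trim the padded llm-awq zeros/scales
# to the real number of groups, then transpose to (groups, out_features).
import torch

out_features, in_features, group_size = 8, 256, 128
padded_width = 8                                     # padded by calculate_zeros_width
zeros = torch.zeros(out_features, padded_width, dtype=torch.int32)
scales = torch.zeros(out_features, padded_width, dtype=torch.float16)

real_scale_num = in_features // group_size           # 2 real groups
zeros = zeros[:, :real_scale_num].t()                # -> (2, 8)
scales = scales[:, :real_scale_num].t()              # -> (2, 8)
print(zeros.shape, scales.shape)
```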
@@ -200,6 +208,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
            new_linear = None
            is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
            is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
            is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
            if is_gptq or is_awq:
                has_bias = module.bias is not None and module.bias.abs().sum() != 0
                new_linear = LowBitLinear(
@@ -213,7 +222,8 @@
                invalidInputError(device.type != "meta",
                                  "converting from meta device is not supported")
                # Copy the weights
                paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
                paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
                                                           llm_awq=is_llm_awq),
                                         requires_grad=False,
                                         quantized=True,
                                         _shape=(out_features, in_features),
@@ -212,8 +212,6 @@ class _BaseAutoModelClass:
                              "Only 4-bit awq is supported in bigdl-llm.")
            invalidInputError(awq_config.version == "gemm",
                              "Only gemm version is supported in bigdl-llm.")
            invalidInputError(awq_config.backend == "autoawq",
                              "Only autoawq backend is supported in bigdl-llm.")
            invalidInputError(awq_config.zero_point,
                              "Only awq zero_point = True is supported in bigdl-llm.")
        if load_in_low_bit is not None:
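With the `backend == "autoawq"` restriction dropped, a checkpoint that declares the llm-awq backend only has to satisfy the remaining checks. An illustrative quantization_config that those checks admit; the field values are derived from the assertions above, not copied from a specific model.

```python
# Illustrative quantization_config admitted by the relaxed checks above
# (derived from the remaining assertions, not from a specific checkpoint).
awq_config = {
    "bits": 4,             # only 4-bit AWQ is supported in bigdl-llm
    "group_size": 128,     # typical AWQ group size (assumed)
    "zero_point": True,    # must remain True
    "version": "gemm",     # only the gemm layout passes the version check
    "backend": "llm-awq",  # rejected before this commit, accepted now
}
```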