Support the desc_act feature in GPTQ model (#10851)

* support act_order

* update versions

* fix style

* fix bug

* clean up
This commit is contained in:
Yang Wang 2024-04-24 10:17:13 -07:00 committed by GitHub
parent dc27b3bc35
commit 1ce8d7bcd9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 56 additions and 20 deletions

View file

@ -13,9 +13,9 @@ conda create -n llm python=3.11
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install transformers==4.34.0
BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
pip install optimum==0.14.0
pip install transformers==4.37.0
pip install auto_gptq==0.7.1
pip install optimum==1.14.0
```
### 2. Configures OneAPI environment variables

View file

@ -18,7 +18,7 @@ import torch
import time
import argparse
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer, GPTQConfig
from transformers import AutoTokenizer, GPTQConfig
# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
@ -30,7 +30,7 @@ LLAMA2_PROMPT_FORMAT = """### HUMAN:
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-GPTQ",
parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
help='The huggingface repo id'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default="What is AI?",
@ -48,8 +48,9 @@ if __name__ == '__main__':
torch_dtype=torch.float,
trust_remote_code=True,).to("xpu")
print(model)
# Load tokenizer
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Generate predicted tokens
with torch.inference_mode():

View file

@ -99,6 +99,11 @@ def is_lm_head(name, model_config, out_features):
return False
def is_gptq_linear(module):
return is_auto_gptq_available() and \
(isinstance(module, QuantLinearCuda) or isinstance(module, QuantLinearCudaOld))
def is_linear_module(module):
in_features = None
@ -122,7 +127,7 @@ def is_linear_module(module):
mp_group = None
else:
result = False
elif is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
elif is_gptq_linear(module):
in_features = module.infeatures
out_features = module.outfeatures
mp_group = None
@ -153,7 +158,7 @@ def is_linear_module(module):
return result, (in_features, out_features, mp_group)
def convert_gptq(module, awq=False, llm_awq=False):
def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")
@ -164,6 +169,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1)
g_id_map = None
if not awq:
zeros = zeros + 1
zeros = zeros.reshape(scales.shape)
@ -183,6 +190,12 @@ def convert_gptq(module, awq=False, llm_awq=False):
weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
if act_order:
invalidInputError(module.g_idx.shape[0] == weight.shape[0],
"g_idx and weight shape mismatch")
_, g_id_map = torch.sort(module.g_idx)
weight = weight[g_id_map, :]
# convert weight to ggml format
weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1])
weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
@ -219,7 +232,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
weight.view(torch.uint8)], dim=-1)
ggml_weight = ggml_weight.reshape([-1])
return ggml_weight
return ggml_weight, g_id_map
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
@ -228,7 +241,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
imatrix_data=None, embedding_qtype=None,
model_config=None, torch_dtype=torch.float32,
enable_xetla=False,
mixed_precision=False):
mixed_precision=False,
act_order=False,
):
from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear
from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@ -252,7 +267,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
optimize_lm_head = True
with init_empty_weights():
new_linear = None
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
is_gptq = is_gptq_linear(module)
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
if is_gptq or is_awq:
@ -264,14 +279,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
bias=has_bias,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head
optimize_lm_head=optimize_lm_head,
act_order=act_order,
)
device = module.qweight.data.device
invalidInputError(device.type != "meta",
"converting from meta device is not supported")
weight, g_idx_map = convert_gptq(module,
awq=is_awq,
llm_awq=is_llm_awq,
act_order=act_order)
if act_order:
new_linear.g_idx_map = g_idx_map
# Copy the weights
paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
llm_awq=is_llm_awq),
paramsLowBit = FP4Params(data=weight,
requires_grad=False,
quantized=True,
_shape=(out_features, in_features),
@ -422,7 +443,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
model_config=model_config,
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision
mixed_precision=mixed_precision,
act_order=act_order,
)
has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced
@ -464,7 +486,7 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
in_features, out_features, mp_group = linear_args
with init_empty_weights():
new_linear = None
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
is_gptq = is_gptq_linear(module)
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
if is_gptq or is_awq:
@ -721,6 +743,10 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
if optimize_model:
model = _optimize_pre(model)
act_order = False
if getattr(model, "quantization_method", None) == "gptq":
act_order = model.config.quantization_config.desc_act
# mixed quantization needs model_config to choose custom quantization strategy
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
@ -731,6 +757,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision,
act_order=act_order,
)
if not has_been_replaced:
warnings.warn(

View file

@ -579,7 +579,7 @@ class MatMulLowBitCPU(torch.autograd.Function):
class LowBitLinear(nn.Linear):
def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None, enable_xetla=False,
optimize_lm_head=False):
optimize_lm_head=False, act_order=False):
super().__init__(input_features, output_features, bias)
self.weight = FP4Params(self.weight.data,
requires_grad=False,
@ -603,6 +603,11 @@ class LowBitLinear(nn.Linear):
# since performance isn't impacted.
self.is_lm_head = self.in_len * self.out_len >= 32000 * 4096 and self.bias is None
self.low_memory_mode = self.is_lm_head
self.act_order = act_order
if act_order:
self.register_buffer(
"g_idx_map",
torch.tensor([i for i in range(self.in_len)], dtype=torch.int64))
def forward(self, x: torch.Tensor):
# empty cache before and after lm_head at first token when input > 1024
@ -640,6 +645,9 @@ class LowBitLinear(nn.Linear):
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
x_2d = x.view(-1, x_shape[-1])
if self.act_order:
x_2d = x_2d[:, self.g_idx_map]
# x0 for weight
x0 = self.weight.data

View file

@ -243,8 +243,6 @@ class _BaseAutoModelClass:
if q_config["quant_method"] == "gptq":
invalidInputError(q_config["bits"] == 4,
"Only 4-bit gptq is supported in bigdl-llm.")
invalidInputError(q_config["desc_act"] is False,
"Only desc_act=False is supported in bigdl-llm.")
if load_in_low_bit is not None:
invalidInputError(load_in_low_bit == "asym_int4",
"You can only load gptq model as aysm_int4 low bit type.")
@ -448,6 +446,8 @@ class _BaseAutoModelClass:
offload_dir=None
)
else:
if quant_config is not None:
kwargs["quantization_config"] = quant_config
_load_pre()
try:
# To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it