Support the desc_act feature in GPTQ model (#10851)
* support act_order
* update versions
* fix style
* fix bug
* clean up
Parent: dc27b3bc35
Commit: 1ce8d7bcd9
5 changed files with 56 additions and 20 deletions
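In GPTQ, `desc_act` (also called act_order) quantizes weight columns following activation order rather than their natural order, so the group id of each input channel, recorded in `g_idx`, no longer forms contiguous blocks of `group_size`. This commit supports such checkpoints by sorting `g_idx` once at conversion time, permuting the packed weight rows accordingly, and applying the same permutation to the activations in `LowBitLinear.forward`. The hunks below touch the example's setup instructions and generation script, the low-bit conversion path (`convert_gptq`, `_replace_with_low_bit_linear`, `ggml_convert_low_bit`), `LowBitLinear`, and the model-loading checks. A minimal sketch of the regrouping idea (toy tensors, not the real packed GPTQ layout):

```python
# Illustrative only: tensor names and sizes are hypothetical and simplified.
import torch

group_size = 2
# g_idx: quantization-group id of each input channel; with desc_act=True the ids
# do not come out as [0, 0, 1, 1, 2, 2] but follow activation order, e.g.:
g_idx = torch.tensor([2, 0, 1, 0, 2, 1])

# Sorting g_idx yields a permutation that makes each group's channels contiguous,
# so downstream code can again treat every `group_size` consecutive rows as one group.
_, g_id_map = torch.sort(g_idx)
print(g_idx[g_id_map])   # tensor([0, 0, 1, 1, 2, 2])
```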
@@ -13,9 +13,9 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.34.0
-BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install transformers==4.37.0
+pip install auto_gptq==0.7.1
+pip install optimum==1.14.0
 ```
 
 ### 2. Configures OneAPI environment variables
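The setup hunk above pins auto_gptq 0.7.1, transformers 4.37.0 and optimum 1.14.0 instead of the old source build of AutoGPTQ. An optional sanity check for the active environment (not part of the example; distribution names assumed as published on PyPI):

```python
# Optional check that the pinned dependencies are what is actually installed.
from importlib.metadata import version

for dist in ("ipex-llm", "transformers", "auto-gptq", "optimum"):
    print(dist, version(dist))
# expected: transformers 4.37.0, auto-gptq 0.7.1, optimum 1.14.0
```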
@@ -18,7 +18,7 @@ import torch
 import time
 import argparse
 from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer, GPTQConfig
+from transformers import AutoTokenizer, GPTQConfig
 
 # you could tune the prompt based on your own model,
 # here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style

@@ -30,7 +30,7 @@ LLAMA2_PROMPT_FORMAT = """### HUMAN:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-GPTQ",
+    parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
                         help='The huggingface repo id'
                              ', or the path to the huggingface checkpoint folder')
     parser.add_argument('--prompt', type=str, default="What is AI?",

@@ -47,9 +47,10 @@ if __name__ == '__main__':
                                                  load_in_4bit=True,
                                                  torch_dtype=torch.float,
                                                  trust_remote_code=True,).to("xpu")
 
+    print(model)
     # Load tokenizer
-    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():
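The script hunks switch the default checkpoint to TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ and load the tokenizer with AutoTokenizer instead of LlamaTokenizer. A condensed sketch of the resulting load-and-generate flow, with prompt formatting and argument parsing omitted and placeholder generation settings; the hunks that follow are the conversion-side changes:

```python
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"

# load_in_4bit=True converts the GPTQ checkpoint into ipex-llm's asym_int4 format
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             torch_dtype=torch.float,
                                             trust_remote_code=True).to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)  # placeholder length
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```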
@@ -99,6 +99,11 @@ def is_lm_head(name, model_config, out_features):
         return False
 
 
+def is_gptq_linear(module):
+    return is_auto_gptq_available() and \
+        (isinstance(module, QuantLinearCuda) or isinstance(module, QuantLinearCudaOld))
+
+
 def is_linear_module(module):
 
     in_features = None

@@ -122,7 +127,7 @@ def is_linear_module(module):
             mp_group = None
         else:
             result = False
-    elif is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
+    elif is_gptq_linear(module):
        in_features = module.infeatures
        out_features = module.outfeatures
        mp_group = None

@@ -153,7 +158,7 @@ def is_linear_module(module):
     return result, (in_features, out_features, mp_group)
 
 
-def convert_gptq(module, awq=False, llm_awq=False):
+def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
 

@@ -164,6 +169,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
                               module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
     zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1)
 
+    g_id_map = None
+
     if not awq:
         zeros = zeros + 1
         zeros = zeros.reshape(scales.shape)

@@ -183,6 +190,12 @@ def convert_gptq(module, awq=False, llm_awq=False):
     weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
     weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
 
+    if act_order:
+        invalidInputError(module.g_idx.shape[0] == weight.shape[0],
+                          "g_idx and weight shape mismatch")
+        _, g_id_map = torch.sort(module.g_idx)
+        weight = weight[g_id_map, :]
+
     # convert weight to ggml format
     weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1])
     weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
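The act_order branch above restores a blocked group layout before the ggml repack: sorting g_idx yields a permutation (g_id_map) under which every group_size consecutive input channels share one quantization group again, so the reshape into (num_groups, group_size, out_features) that follows stays valid. A toy illustration of that step (unpacked float weights instead of the real packed int tensors):

```python
import torch

group_size = 2
in_features, out_features = 6, 3
# toy unpacked weight, rows indexed by input channel (as in convert_gptq at this point)
weight = torch.arange(in_features * out_features, dtype=torch.float32).reshape(in_features, out_features)

# hypothetical g_idx of a desc_act checkpoint: group id per input channel, out of order
g_idx = torch.tensor([2, 0, 1, 0, 2, 1])
_, g_id_map = torch.sort(g_idx)
weight = weight[g_id_map, :]          # regroup rows so each group is contiguous

# the blocked reshape used for the ggml conversion is now well defined
blocked = weight.reshape(in_features // group_size, group_size, out_features)
print(g_idx[g_id_map])                # tensor([0, 0, 1, 1, 2, 2])
print(blocked.shape)                  # torch.Size([3, 2, 3])
```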
@@ -219,7 +232,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
                               weight.view(torch.uint8)], dim=-1)
     ggml_weight = ggml_weight.reshape([-1])
 
-    return ggml_weight
+    return ggml_weight, g_id_map
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,

@@ -228,7 +241,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
                                  enable_xetla=False,
-                                 mixed_precision=False):
+                                 mixed_precision=False,
+                                 act_order=False,
+                                 ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding

@@ -252,7 +267,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 optimize_lm_head = True
             with init_empty_weights():
                 new_linear = None
-                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_gptq = is_gptq_linear(module)
                 is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
                 is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
                 if is_gptq or is_awq:

@@ -264,14 +279,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                               bias=has_bias,
                                               mp_group=mp_group,
                                               enable_xetla=enable_xetla,
-                                              optimize_lm_head=optimize_lm_head
+                                              optimize_lm_head=optimize_lm_head,
+                                              act_order=act_order,
                                               )
                     device = module.qweight.data.device
                     invalidInputError(device.type != "meta",
                                       "converting from meta device is not supported")
+                    weight, g_idx_map = convert_gptq(module,
+                                                     awq=is_awq,
+                                                     llm_awq=is_llm_awq,
+                                                     act_order=act_order)
+                    if act_order:
+                        new_linear.g_idx_map = g_idx_map
                     # Copy the weights
-                    paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
-                                                               llm_awq=is_llm_awq),
+                    paramsLowBit = FP4Params(data=weight,
                                              requires_grad=False,
                                              quantized=True,
                                              _shape=(out_features, in_features),

@@ -422,7 +443,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model_config=model_config,
                 torch_dtype=torch_dtype,
                 enable_xetla=enable_xetla,
-                mixed_precision=mixed_precision
+                mixed_precision=mixed_precision,
+                act_order=act_order,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced

@@ -464,7 +486,7 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
         in_features, out_features, mp_group = linear_args
         with init_empty_weights():
             new_linear = None
-            is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+            is_gptq = is_gptq_linear(module)
             is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
             is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
             if is_gptq or is_awq:

@@ -721,6 +743,10 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if optimize_model:
         model = _optimize_pre(model)
 
+    act_order = False
+    if getattr(model, "quantization_method", None) == "gptq":
+        act_order = model.config.quantization_config.desc_act
+
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,

@@ -731,6 +757,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
+        act_order=act_order,
     )
     if not has_been_replaced:
         warnings.warn(
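With the last two hunks above, ggml_convert_low_bit reads act_order straight from the checkpoint's quantization config and threads it through _replace_with_low_bit_linear. For GPTQ checkpoints that flag lives in config.json under quantization_config, which is where model.config.quantization_config.desc_act comes from; one way to inspect it (repo id reused from the example; depending on the transformers version the attribute may be a plain dict):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ")
# For GPTQ checkpoints this typically includes bits, group_size and desc_act.
print(cfg.quantization_config)
```

The next hunks are the LowBitLinear side of the change.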
@@ -579,7 +579,7 @@ class MatMulLowBitCPU(torch.autograd.Function):
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None, enable_xetla=False,
-                 optimize_lm_head=False):
+                 optimize_lm_head=False, act_order=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,

@@ -603,6 +603,11 @@ class LowBitLinear(nn.Linear):
         # since performance isn't impacted.
         self.is_lm_head = self.in_len * self.out_len >= 32000 * 4096 and self.bias is None
         self.low_memory_mode = self.is_lm_head
+        self.act_order = act_order
+        if act_order:
+            self.register_buffer(
+                "g_idx_map",
+                torch.tensor([i for i in range(self.in_len)], dtype=torch.int64))
 
     def forward(self, x: torch.Tensor):
         # empty cache before and after lm_head at first token when input > 1024
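g_idx_map is registered as a buffer rather than stored as a plain attribute, presumably so that it follows the module across .to(device) calls and ends up in state_dict without becoming a trainable parameter. A stripped-down illustration (toy module, not the real LowBitLinear):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, in_len=8):
        super().__init__()
        # identity permutation by default, same dtype as in the patch
        self.register_buffer("g_idx_map", torch.arange(in_len, dtype=torch.int64))

m = Toy()
print("g_idx_map" in m.state_dict())   # True: buffers are serialized
print(m.g_idx_map.requires_grad)       # False: buffers are not parameters
```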
@@ -640,6 +645,9 @@ class LowBitLinear(nn.Linear):
             return torch.empty(new_shape, dtype=x.dtype, device=x.device)
 
         x_2d = x.view(-1, x_shape[-1])
+
+        if self.act_order:
+            x_2d = x_2d[:, self.g_idx_map]
         # x0 for weight
         x0 = self.weight.data
 
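At run time, forward applies to the activation columns the same permutation that convert_gptq applied to the weight rows, so the matmul result is unchanged and only the packed layout is reordered. A toy check of that invariant with plain float tensors; the remaining hunks then drop the old desc_act=False guard in the loading path and forward the user's quantization_config:

```python
import torch

torch.manual_seed(0)
in_features, out_features = 8, 4
x = torch.randn(2, in_features)
w = torch.randn(in_features, out_features)      # rows indexed by input channel

g_idx = torch.tensor([3, 0, 1, 2, 0, 3, 1, 2])  # hypothetical group ids
_, g_id_map = torch.sort(g_idx)

w_reordered = w[g_id_map, :]                    # done once, at conversion time
x_reordered = x[:, g_id_map]                    # done per call, as in forward()

# permuting both sides of the contraction dimension leaves the product unchanged
assert torch.allclose(x @ w, x_reordered @ w_reordered)
```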
@@ -243,8 +243,6 @@ class _BaseAutoModelClass:
             if q_config["quant_method"] == "gptq":
                 invalidInputError(q_config["bits"] == 4,
                                   "Only 4-bit gptq is supported in bigdl-llm.")
-                invalidInputError(q_config["desc_act"] is False,
-                                  "Only desc_act=False is supported in bigdl-llm.")
                 if load_in_low_bit is not None:
                     invalidInputError(load_in_low_bit == "asym_int4",
                                       "You can only load gptq model as aysm_int4 low bit type.")

@@ -448,6 +446,8 @@ class _BaseAutoModelClass:
                     offload_dir=None
                 )
             else:
+                if quant_config is not None:
+                    kwargs["quantization_config"] = quant_config
                 _load_pre()
                 try:
                     # To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it