diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md
new file mode 100644
index 00000000..7b38158f
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md
@@ -0,0 +1,73 @@
+# GPTQ
+This example shows how to directly run 4-bit GPTQ models using BigDL-LLM on Intel CPU. For illustration purposes, we use ["TheBloke/Llama-2-7B-GPTQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) as a reference.
+
+## 0. Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.9
+conda activate llm
+
+pip install bigdl-llm[all] # install bigdl-llm with 'all' option
+pip install transformers==4.34.0
+BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
+pip install optimum==0.14.0
+```
+
+### 2. Run
+```
+python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2-gptq model (e.g. `TheBloke/Llama-2-7B-GPTQ`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'TheBloke/Llama-2-7B-GPTQ'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, an *X*B model saved in 16-bit requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB of memory for further inference.
+>
+> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
+
+#### 2.1 Client
+On a client Windows machine, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./generate.py
+```
+
+#### 2.2 Server
+For optimal performance on a server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information) and run the example with all the physical cores of a single socket.
+
+E.g. on Linux:
+```bash
+# set BigDL-Nano env variables
+source bigdl-nano-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./generate.py
+```
+
+#### 2.3 Sample Output
+#### [TheBloke/Llama-2-7B-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ)
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+### HUMAN:
+What is AI?
+
+### RESPONSE:
+
+-------------------- Output --------------------
+### HUMAN:
+What is AI?
+
+### RESPONSE:
+
+> AI is a branch of computer science that aims to create intelligent machines that think and act like humans.
+ +### HUMAN +``` \ No newline at end of file diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py new file mode 100644 index 00000000..70ccef6d --- /dev/null +++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py @@ -0,0 +1,72 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import time +import argparse + +from bigdl.llm.transformers import AutoModelForCausalLM +from transformers import LlamaTokenizer, GPTQConfig + +# you could tune the prompt based on your own model, +# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style +LLAMA2_PROMPT_FORMAT = """### HUMAN: +{prompt} + +### RESPONSE: +""" + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-GPTQ", + help='The huggingface repo id' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + # Load model in 4 bit, + # which convert the relevant layers in the model into INT4 format + model = AutoModelForCausalLM.from_pretrained(model_path, + load_in_4bit=True, + torch_dtype=torch.float, + trust_remote_code=True,) + + # Load tokenizer + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Generate predicted tokens + with torch.inference_mode(): + prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt) + input_ids = tokenizer.encode(prompt, return_tensors="pt") + st = time.time() + # if your selected model is capable of utilizing previous key/value attentions + # to enhance decoding speed, but has `"use_cache": false` in its model config, + # it is important to set `use_cache=True` explicitly in the `generate` function + # to obtain optimal performance with BigDL-LLM INT4 optimizations + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + end = time.time() + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md new file mode 100644 index 00000000..280fab96 --- /dev/null +++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md @@ -0,0 +1,67 @@ +# GPTQ +This example shows how to directly run 4-bit GPTQ models 
using BigDL-LLM on Intel GPU. For illustration purposes, we use ["TheBloke/Llama-2-7B-GPTQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) as a reference.
+
+## 0. Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+pip install transformers==4.34.0
+BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
+pip install optimum==0.14.0
+```
+
+### 2. Configure OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Run
+
+For optimal performance on Arc, it is recommended to set several environment variables.
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+```
+python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2-gptq model (e.g. `TheBloke/Llama-2-7B-GPTQ`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'TheBloke/Llama-2-7B-GPTQ'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, an *X*B model saved in 16-bit requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB of memory for further inference.
+>
+> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
+
+#### 3.1 Sample Output
+#### [TheBloke/Llama-2-7B-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ)
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+### HUMAN:
+What is AI?
+
+### RESPONSE:
+
+-------------------- Output --------------------
+### HUMAN:
+What is AI?
+
+### RESPONSE:
+
+> AI is a branch of computer science that aims to create intelligent machines that think and act like humans.
+
+### HUMAN
+```
\ No newline at end of file
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
new file mode 100644
index 00000000..5d77801d
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
@@ -0,0 +1,72 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import time +import argparse +import intel_extension_for_pytorch as ipex +from bigdl.llm.transformers import AutoModelForCausalLM +from transformers import LlamaTokenizer, GPTQConfig + +# you could tune the prompt based on your own model, +# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style +LLAMA2_PROMPT_FORMAT = """### HUMAN: +{prompt} + +### RESPONSE: +""" + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-GPTQ", + help='The huggingface repo id' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + # Load model in 4 bit, + # which convert the relevant layers in the model into INT4 format + model = AutoModelForCausalLM.from_pretrained(model_path, + load_in_4bit=True, + torch_dtype=torch.float, + trust_remote_code=True,).to("xpu") + + # Load tokenizer + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Generate predicted tokens + with torch.inference_mode(): + prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to("xpu") + st = time.time() + # if your selected model is capable of utilizing previous key/value attentions + # to enhance decoding speed, but has `"use_cache": false` in its model config, + # it is important to set `use_cache=True` explicitly in the `generate` function + # to obtain optimal performance with BigDL-LLM INT4 optimizations + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + end = time.time() + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 6e4ccae5..a1a9f7f1 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -41,22 +41,38 @@ import torch.nn as nn from accelerate import init_empty_weights import warnings import transformers -import importlib +import importlib.util from bigdl.llm.ggml.quantize import ggml_tensor_qtype from .utils import logger +from typing import Union +import numpy as np +from bigdl.llm.utils.common import invalidInputError + + +def is_auto_gptq_available(): + return importlib.util.find_spec("auto_gptq") is not None def is_deepspeed_available(): return importlib.util.find_spec("deepspeed") is not None +if is_auto_gptq_available(): + from auto_gptq.utils.peft_utils import QuantLinearCuda, QuantLinearCudaOld + + def is_linear_module(module): in_features = None out_features = None 
mp_group = None - if isinstance(module, nn.Linear): + if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld): + in_features = module.infeatures + out_features = module.outfeatures + mp_group = None + result = True + elif isinstance(module, nn.Linear): in_features = module.in_features out_features = module.out_features mp_group = None @@ -82,6 +98,61 @@ def is_linear_module(module): return result, (in_features, out_features, mp_group) +from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size +Q4_1 = get_ggml_qk_size("asym_int4") + + +def convert_gptq(module): + + scales = module.scales + + zeros = torch.bitwise_right_shift( + torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // module.bits), + module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8) + zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1) + + zeros = zeros + 1 + zeros = zeros.reshape(scales.shape) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1), + module.wf.unsqueeze(-1)).to(torch.int8) + weight = torch.bitwise_and(weight, (2 ** module.bits) - 1) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + # convert weight to ggml format + weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1]) + weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2) + weight = weight.transpose(2, 3) + weight = torch.bitwise_left_shift(weight, + torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2)) + weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous() + + # convert zeros to ggml format + zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\ + .unsqueeze(2)\ + .expand(-1, -1, module.group_size//Q4_1, -1)\ + .reshape(zeros.shape[1], -1, 1)\ + .contiguous().to(torch.float16) + + # convert scales to ggml format + scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\ + .unsqueeze(2)\ + .expand(-1, -1, module.group_size//Q4_1, -1)\ + .reshape(scales.shape[-1], -1, 1)\ + .contiguous().to(torch.float16) + + m = -(zeros * scales) + d = scales + + ggml_weight = torch.cat([d.view(torch.uint8), + m.view(torch.uint8), + weight.view(torch.uint8)], dim=-1) + ggml_weight = ggml_weight.reshape([-1]) + + return ggml_weight + + def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, current_key_name=None, convert_shape_only=False, replace_embedding=False): @@ -100,7 +171,30 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, in_features, out_features, mp_group = linear_args with init_empty_weights(): new_linear = None - if qtype != ggml_tensor_qtype["fp16"]: + if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld): + has_bias = module.bias is not None and module.bias.abs().sum() != 0 + new_linear = LowBitLinear( + in_features, + out_features, + qtype=qtype, + bias=has_bias, + mp_group=mp_group, + ) + device_type = module.qweight.data.device.type + invalidInputError(device_type != "meta", + "converting from meta device is not supported") + # Copy the weights + paramsLowBit = FP4Params(data=convert_gptq(module), + requires_grad=False, + quantized=True, + _shape=(out_features, in_features), + convert_shape_only=convert_shape_only, + qtype=qtype).to(device_type) + new_linear._parameters['weight'] = paramsLowBit + if has_bias: + new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ + .to(device_type) + elif qtype != ggml_tensor_qtype["fp16"]: 
new_linear = LowBitLinear( in_features, out_features, @@ -118,6 +212,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, convert_shape_only=convert_shape_only, qtype=qtype).to(device_type) new_linear._parameters['weight'] = paramsLowBit + if module.bias is not None: + new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ + .to(device_type) else: # only support two size now # may generalize to other sizes @@ -137,13 +234,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, trans_weight = module.weight.data.reshape(m//16, 16, n) trans_weight = trans_weight.transpose(1, 2).contiguous() new_linear._parameters['weight'] = nn.Parameter(trans_weight) + if module.bias is not None: + new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ + .to(device_type) # fp16 may generalize to other sizes later if new_linear is not None: - if module.bias is not None: - new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ - .to(device_type) - model._modules[name] = new_linear has_been_replaced = True # Force requires grad to False to avoid unexpected errors @@ -223,7 +319,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True, "an issue on github if you think this is a bug." ) elif device == "cpu": - model.to(torch.float32) + if not (getattr(model, "quantization_method", None) == "gptq"): + model.to(torch.float32) elif device == "meta": # Do nothing here for weights are empty. pass diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index f232558f..00deef0d 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -70,6 +70,10 @@ MOFQ4 = ggml_tensor_qtype["mixed_fp4"] MOFQ8 = ggml_tensor_qtype["mixed_fp8"] +def get_ggml_qk_size(qtype: str): + return ggml.ggml_qk_size(ggml_tensor_qtype[qtype]) + + def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None, convert_shape_only=False): QK = ggml.ggml_qk_size(qtype) diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 93a42876..74aae079 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -22,6 +22,7 @@ from .utils import extract_local_archive_file, \ from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.utils.common import invalidInputError import torch +import warnings import copy from .utils import logger @@ -30,6 +31,10 @@ def save_low_bit(self, *args, **kwargs): invalidInputError(self.config.to_dict().get("bigdl_transformers_low_bit", False), f"Detected this model is not a low-bit model, please use from_pretrained's" f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.") + if hasattr(self.config, "quantization_config"): + delattr(self.config, "quantization_config") + delattr(self.config, "_pre_quantization_dtype") + self.to('cpu') self.save_pretrained(*args, **kwargs) import json @@ -57,7 +62,9 @@ class _BaseAutoModelClass: Three new arguments are added to extend Hugging Face's from_pretrained method as follows: - :param load_in_4bit: boolean value, True means load linear's weight to symmetric int 4. + :param load_in_4bit: boolean value, True means loading linear's weight to symmetric int 4 if + the model is a regular fp16/bf16/fp32 model, and to asymmetric int 4 + if the model is GPTQ model. Default to be False. 
 :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
                                 , sym_int8, nf3, nf4, fp4, fp8 or fp16. sym_int4 means symmetric
@@ -70,7 +77,6 @@ class _BaseAutoModelClass:
                                        conducting model optimizations. Default to be None.
         :param replace_embedding: Whether to replace the Embedding layer, may need to set it
                                    to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
-
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -87,8 +93,37 @@ class _BaseAutoModelClass:
         load_in_4bit = kwargs.pop("load_in_4bit", False)
         load_in_low_bit = kwargs.pop("load_in_low_bit", None)
         optimize_model = kwargs.pop("optimize_model", True)
+        user_quantization_config = kwargs.pop("quantization_config", None)

         if load_in_4bit or load_in_low_bit:
+
+            if config_dict.get("quantization_config", None) is not None:
+                from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size
+                q_config = config_dict["quantization_config"]
+                if q_config["quant_method"] == "gptq":
+                    invalidInputError(q_config["bits"] == 4,
+                                      "Only 4-bit gptq is supported in bigdl-llm.")
+                    invalidInputError(q_config["desc_act"] is False,
+                                      "Only desc_act=False is supported in bigdl-llm.")
+                    if load_in_low_bit is not None:
+                        invalidInputError(load_in_low_bit == "asym_int4",
+                                          "You can only load gptq model as asym_int4 low bit type.")
+
+                    load_in_low_bit = "asym_int4"
+                    if int(q_config["group_size"]) % get_ggml_qk_size(load_in_low_bit) != 0:
+                        invalidInputError(False,
+                                          (f"group_size must be divisible by "
+                                           f"{get_ggml_qk_size(load_in_low_bit)}."))
+                    if user_quantization_config is not None:
+                        invalidInputError(user_quantization_config.bits == 4,
+                                          "Only 4-bit gptq is supported in bigdl-llm.")
+                        invalidInputError(user_quantization_config.use_exllama is False,
+                                          "Only use_exllama=False is supported in bigdl-llm.")
+                    else:
+                        from transformers import GPTQConfig
+                        user_quantization_config = GPTQConfig(bits=4, use_exllama=False)
+                    kwargs["quantization_config"] = user_quantization_config
+
             # load int x-bit
             kwargs["low_cpu_mem_usage"] = True
             # set default torch_dtype='auto'
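
A note on the `convert_gptq` helper added in `convert.py` above: GPTQ checkpoints store per-group `scales` and `qzeros` and dequantize a 4-bit code `q` as `(q - zero) * scale`, while ggml's asymmetric Q4_1 blocks store a scale `d` and a minimum `m` and dequantize as `q * d + m`. The patch maps one representation onto the other with `d = scales` and `m = -(zeros * scales)`, which is why the packed 4-bit codes only need to be repacked into the Q4_1 nibble layout rather than re-quantized. The standalone sketch below is not part of the patch (the group size, scale, and zero point are made-up illustrative values); it simply checks numerically that the two dequantization rules agree under this mapping.

```python
import torch

# Illustrative per-group data: 4-bit codes in [0, 15], one scale and one
# zero point (already +1-adjusted, as convert_gptq does for qzeros).
q = torch.randint(0, 16, (32,)).float()
scale = torch.tensor(0.037)
zero = torch.tensor(8.0)

# GPTQ-style dequantization: (q - zero) * scale
w_gptq = (q - zero) * scale

# ggml Q4_1-style dequantization: q * d + m, using the mapping from convert_gptq
d = scale
m = -(zero * scale)
w_q4_1 = q * d + m

# The two rules produce identical weights when d = scale and m = -(zero * scale)
assert torch.allclose(w_gptq, w_q4_1)
print("max abs difference:", (w_gptq - w_q4_1).abs().max().item())
```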