LLM: add GGUF-IQ2 examples (#10207)

* add iq2 examples

* small fix

* meet code review

* fix

* meet review

* small fix
This commit is contained in:
Ruonan Wang 2024-02-22 14:18:45 +08:00 committed by GitHub
parent 21de2613ce
commit 5e1fee5e05
6 changed files with 174 additions and 9 deletions

View file

@ -0,0 +1,81 @@
# GGUF-IQ2
This example shows how to run INT2 models using the IQ2 mechanism (first implemented by llama.cpp) in BigDL-LLM on Intel GPU.
## Verified Models
- [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), using [llama-v2-7b.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/llama-v2-7b.imatrix)
- [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), using [llama-v2-7b.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/llama-v2-7b.imatrix)
- [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2), using [mistral-7b-instruct-v0.2.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/mistral-7b-instruct-v0.2.imatrix)
- [Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1), using [mixtral-8x7b.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/mixtral-8x7b.imatrix)
- [Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), using [mixtral-8x7b-instruct-v0.1.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/mixtral-8x7b-instruct-v0.1.imatrix)
## Requirements
To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../../../README.md#requirements) for more information.
## Example: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case for a GGUF-IQ2 model to predict the next N tokens using `generate()` API, with BigDL-LLM optimizations.
### 1. Install
We suggest using conda to manage environment:
```bash
conda create -n llm python=3.9
conda activate llm
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.35.0
```
**Note: For Mixtral model, please use transformers 4.36.0:**
```bash
pip install transformers==4.36.0
```
### 2. Configures OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. Run
For optimal performance on Arc, it is recommended to set several environment variables.
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
```
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
#### 2.3 Sample Output
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
```log
Inference time: xxxx s
-------------------- Prompt --------------------
### HUMAN:
What is AI?
### RESPONSE:
-------------------- Output --------------------
### HUMAN:
What is AI?
### RESPONSE:
Artificial intelligence (AI) refers to the ability of machines to perform tasks that would typically require human intelligence, such as learning, problem-solving
```

View file

@ -0,0 +1,83 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import time
import argparse
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import warnings
# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
PROMPT_FORMAT = """### HUMAN:
{prompt}
### RESPONSE:
"""
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for LLM model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
help='The huggingface repo id'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument('--n-predict', type=int, default=32,
help='Max tokens to predict')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
warnings.warn("iq2 quantization may need several minutes, please wait a moment, "
"or have a cup of coffee now : )")
# Load model in 2 bit,
# which convert the relevant layers in the model into gguf_iq2_xxs format.
# GGUF-IQ2 quantization needs imatrix file to assist in quantization
# and improve generation quality, and different model may need different
# imtraix file, you can find and download imatrix file from
# https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/tree/main.
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit='gguf_iq2_xxs',
trust_remote_code=True,
imatrix='llama-v2-7b.imatrix').to("xpu")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Generate predicted tokens
with torch.inference_mode():
prompt = PROMPT_FORMAT.format(prompt=args.prompt)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("xpu")
# ipex model needs a warmup, then inference time can be accurate
output = model.generate(input_ids,
max_new_tokens=args.n_predict)
st = time.time()
# if your selected model is capable of utilizing previous key/value attentions
# to enhance decoding speed, but has `"use_cache": false` in its model config,
# it is important to set `use_cache=True` explicitly in the `generate` function
# to obtain optimal performance with BigDL-LLM Low Bit optimizations
output = model.generate(input_ids,
max_new_tokens=args.n_predict,
repetition_penalty=1.1)
end = time.time()
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Inference time: {end-st} s')
print('-'*20, 'Prompt', '-'*20)
print(prompt)
print('-'*20, 'Output', '-'*20)
print(output_str)

View file

@ -40,8 +40,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
"fp8_e5m2": 19, # fp8 in e5m2 format "fp8_e5m2": 19, # fp8 in e5m2 format
"fp8": 19, # fp8 in e5m2 format "fp8": 19, # fp8 in e5m2 format
"bf16": 20, "bf16": 20,
"iq2_xxs": 21, "gguf_iq2_xxs": 21,
"iq2_xs": 22, "gguf_iq2_xs": 22,
"q2_k": 23} "q2_k": 23}
_llama_quantize_type = {"q4_0": 2, _llama_quantize_type = {"q4_0": 2,

View file

@ -70,8 +70,8 @@ FP4 = ggml_tensor_qtype["fp4"]
MOFQ4 = ggml_tensor_qtype["mixed_fp4"] MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
MOFQ8 = ggml_tensor_qtype["mixed_fp8"] MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"] FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
IQ2_XXS = ggml_tensor_qtype["iq2_xxs"] IQ2_XXS = ggml_tensor_qtype["gguf_iq2_xxs"]
IQ2_XS = ggml_tensor_qtype["iq2_xs"] IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
Q2_K = ggml_tensor_qtype["q2_k"] Q2_K = ggml_tensor_qtype["q2_k"]

View file

@ -110,7 +110,7 @@ class _BaseAutoModelClass:
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, :param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, ``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``, ``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``, ``'fp16'`` or ``'bf16'``,
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc. asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
Relevant low bit optimizations will be applied to the model. Relevant low bit optimizations will be applied to the model.
@ -278,12 +278,13 @@ class _BaseAutoModelClass:
kwargs["pretraining_tp"] = 1 kwargs["pretraining_tp"] = 1
q_k = load_in_low_bit if load_in_low_bit else "sym_int4" q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
imatrix_file = kwargs.pop("imatrix", None) imatrix_file = kwargs.pop("imatrix", None)
if q_k in ["iq2_xxs", "iq2_xs"]: if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs"]:
invalidInputError(imatrix_file is not None, invalidInputError(imatrix_file is not None,
"For iq2_xxs and iq2_xs quantization, imatrix is needed.") "For gguf_iq2_xxs and gguf_iq2_xs quantization,"
"imatrix is needed.")
cpu_embedding = kwargs.get("cpu_embedding", False) cpu_embedding = kwargs.get("cpu_embedding", False)
# for 2bit, default use embedding_quantization # for 2bit, default use embedding_quantization
if q_k in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding and \ if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "q2_k"] and not cpu_embedding and \
embedding_qtype is None: embedding_qtype is None:
embedding_qtype = "q2_k" embedding_qtype = "q2_k"
if imatrix_file is not None: if imatrix_file is not None:

View file

@ -269,7 +269,7 @@ def module_name_process(full_module_name):
def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None): def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None):
cur_qtype = qtype cur_qtype = qtype
if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]: if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"]]:
# For quantization which needs importance matrix # For quantization which needs importance matrix
new_module_name, layer, cur_module = module_name_process(full_module_name) new_module_name, layer, cur_module = module_name_process(full_module_name)
# custom mixed quantization strategy # custom mixed quantization strategy