[LLM] support ipex arc int4 & add basic llama2 example (#8700)

* first support of xpu

* make it work on gpu

update setup

update

add GPU llama2 examples

add use_optimize flag to disable optimize for gpu

fix style

update gpu example readme

fix

* update example, and update env

* fix setup to add cpp files

* replace jit with aot to avoid data leak

* rename to bigdl-core-xe

* update installation in example readme
Ruonan Wang 2023-08-09 22:20:32 +08:00 committed by GitHub
parent d03218674a
commit 1a7b698a83
8 changed files with 232 additions and 19 deletions

View file

@@ -0,0 +1,15 @@
# BigDL-LLM Transformers INT4 Optimization for Large Language Model on Intel® Arc™ A-Series Graphics
You can use BigDL-LLM to run almost every Hugging Face Transformers model with INT4 optimizations on your laptop with Intel® Arc™ A-Series Graphics. This directory contains example scripts to help you quickly get started using BigDL-LLM to run some popular open-source models in the community. Each model has its own dedicated folder, where you can find detailed instructions on how to install and run it.
## Recommended Requirements
To apply Intel® Arc™ A-Series Graphics acceleration, there are several steps for tool installation and environment preparation.
Step 1, only Linux is supported for now; Ubuntu 22.04 is preferred.
Step 2, please refer to our [driver installation](https://dgpu-docs.intel.com/installation-guides/index.html#intel-arc-gpus) for general purpose GPU capabilities.
Step 3, you also need to download and install the [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html). oneMKL and the DPC++ compiler are required; the other components are optional.
## Best Known Configuration on Linux
For better performance, it is recommended to set environment variables on Linux:
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
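Once the driver, oneAPI, and these variables are in place, a quick sanity check that PyTorch can actually see the Arc GPU can save debugging time later. The snippet below is only an illustrative check; it assumes `bigdl-llm[xpu]` and Intel® Extension for PyTorch are already installed as described in the per-model examples:
```python
import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device with PyTorch

# Both calls come from the IPEX XPU build; on a working Arc setup the first
# should print True and the second the name of your A-Series card.
print(torch.xpu.is_available())
print(torch.xpu.get_device_name(0))
```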

View file

@@ -0,0 +1,78 @@
# Llama2
In this directory, you will find examples of how to apply BigDL-LLM INT4 optimizations to Llama2 models on Intel® Arc™ A-Series Graphics. For illustration purposes, we use [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) and [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) as reference Llama2 models.
## 0. Requirements
To run these examples with BigDL-LLM on Intel® Arc™ A-Series Graphics, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
## Example: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations on Intel® Arc™ A-Series Graphics.
### 1. Install
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.9
conda activate llm
# the command below installs intel_extension_for_pytorch==2.0.110+xpu by default
# you can install a specific ipex/torch version for your needs
pip install bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
# download the wheel from SourceForge (https://sourceforge.net/projects/analytics-zoo/files/bigdl-llm/bigdl_core_xe-0.0.0-cp39-cp39-linux_x86_64.whl/download), then install it
pip install bigdl_core_xe-0.0.0-cp39-cp39-linux_x86_64.whl
```
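After installation, you may want to confirm that the key packages import correctly inside the `llm` environment. The following is a minimal, optional check; the version numbers in the comments are the ones targeted by this setup, not guarantees:
```python
import torch
import intel_extension_for_pytorch as ipex               # XPU backend for PyTorch
from bigdl.llm.transformers import AutoModelForCausalLM  # BigDL-LLM transformers API

print(torch.__version__)  # expected to be a 2.0.1a0 XPU build
print(ipex.__version__)   # expected to be 2.0.110+xpu
```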
### 2. Configure OneAPI Environment Variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. Run
For optimal performance on Arc, it is recommended to set several environment variables.
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
```bash
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
```
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the Hugging Face checkpoint folder. It defaults to `'meta-llama/Llama-2-7b-chat-hf'`.
- `--prompt PROMPT`: argument defining the prompt to run inference on (with an integrated prompt format for chat). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the maximum number of tokens to predict. It defaults to `32`.
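If you prefer to call the API directly instead of running the script, the arguments above roughly correspond to the condensed sketch below. It mirrors what `generate.py` (shown in full later in this commit) does, with the model id and prompt hard-coded for illustration:
```python
import torch
import intel_extension_for_pytorch as ipex
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

model_path = "meta-llama/Llama-2-7b-chat-hf"      # --repo-id-or-model-path
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=False,
                                             trust_remote_code=True)
model = model.half().to('xpu')
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to('xpu')  # --prompt
    output = model.generate(input_ids, max_new_tokens=32)                       # --n-predict
    print(tokenizer.decode(output[0].cpu(), skip_special_tokens=True))
```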
#### Sample Output
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
```log
Inference time: xxxx s
-------------------- Prompt --------------------
### HUMAN:
What is AI?
### RESPONSE:
-------------------- Output --------------------
### HUMAN:
What is AI?
### RESPONSE:
AI is a term used to describe the development of computer systems that can perform tasks that typically require human intelligence, such as understanding natural language, recognizing images
```
#### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
```log
Inference time: xxxx s
-------------------- Prompt --------------------
### HUMAN:
What is AI?
### RESPONSE:
-------------------- Output --------------------
### HUMAN:
What is AI?
### RESPONSE:
AI, or artificial intelligence, refers to the ability of machines to perform tasks that would typically require human intelligence, such as learning, problem-solving,
```

View file

@@ -0,0 +1,76 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import time
import argparse
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer
import intel_extension_for_pytorch as ipex
# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
LLAMA2_PROMPT_FORMAT = """### HUMAN:
{prompt}
### RESPONSE:
"""
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')
    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    # Load the model in 4 bit,
    # which converts the relevant layers in the model into INT4 format
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=False,
                                                 trust_remote_code=True)
    model = model.half().to('xpu')

    # Load tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu')
        st = time.time()
        # if your selected model is capable of utilizing previous key/value attentions
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
        # it is important to set `use_cache=True` explicitly in the `generate` function
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()
        output = output.cpu()
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Prompt', '-'*20)
        print(prompt)
        print('-'*20, 'Output', '-'*20)
        print(output_str)

View file

@@ -34,6 +34,7 @@ import urllib.request
import requests
import re
import glob
import copy
from setuptools import setup
@@ -247,6 +248,13 @@ def setup_package():
    all_requires = ['py-cpuinfo']
    all_requires += CONVERT_DEP

    # install with -f https://developer.intel.com/ipex-whl-stable-xpu
    xpu_requires = copy.deepcopy(all_requires)
    xpu_requires.remove('torch')
    xpu_requires += ["torch==2.0.1a0",
                     "torchvision==0.15.2a0",
                     "intel_extension_for_pytorch==2.0.110+xpu;platform_system=='Linux'"]

    metadata = dict(
        name='bigdl-llm',
        version=VERSION,
@@ -267,7 +275,8 @@ def setup_package():
                'llm-convert=bigdl.llm.convert_model:main'
            ]
        },
        extras_require={"all": all_requires,
                        "xpu": xpu_requires},
        classifiers=[
            'License :: OSI Approved :: Apache Software License',
            'Programming Language :: Python :: 3',

View file

@@ -73,7 +73,6 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                with init_empty_weights():
                    new_linear = LinearQuant(
                        module.in_features,
                        module.out_features,
@@ -112,7 +111,7 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
    return model, has_been_replaced


def ggml_convert_quant(model, qtype, optimize_model=True, convert_shape_only=False):
    modules_to_not_convert = []  # ["lm_head"]
    model, has_been_replaced = _replace_with_quant_linear(
        model, qtype, modules_to_not_convert, None, convert_shape_only=convert_shape_only
@@ -127,7 +126,8 @@ def ggml_convert_quant(model, qtype, convert_shape_only=False):
    else:
        model.to(torch.float32)

    if optimize_model:
        model = optimize(model)
    return model
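With the new `optimize_model` flag, `ggml_convert_quant` can quantize a model without also applying BigDL-LLM's model-specific optimizations. Below is a minimal sketch of calling it directly; the `bigdl.llm.transformers.convert` module path is assumed from the relative import in `model.py`, and in normal use you would go through `AutoModelForCausalLM` instead:
```python
from transformers import AutoModelForCausalLM  # plain Hugging Face class, for illustration
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import ggml_convert_quant  # module path assumed from this diff

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
# quantize the linear layers to sym_int4 but skip the model-specific optimizations
model = ggml_convert_quant(model, ggml_tensor_qtype["sym_int4"], optimize_model=False)
```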

View file

@@ -43,7 +43,7 @@
from typing import Optional, TypeVar, Union, overload
from bigdl.llm.utils.common import invalidInputError
import os
import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
@@ -52,8 +52,6 @@ T = TypeVar("T", bound="torch.nn.Module")
import bigdl.llm.ggml.model.llama.llama_cpp as ggml
from bigdl.llm.utils.isa_checker import is_server
import torch
import ctypes
from bigdl.llm.ggml.quantize import ggml_tensor_qtype

IS_SERVER = is_server()
@@ -152,6 +150,17 @@ class ParamsQuant(torch.nn.Parameter):
        if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
            return self.quantize(device)
        elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
            # enter xpu logic, compile linear_int4 extension at first time
            q_tensor = self.quantize(device)  # tensor is cpu now
            new_param = ParamsQuant(super().to(device=device,
                                               dtype=dtype,
                                               non_blocking=non_blocking),
                                    requires_grad=self.requires_grad,
                                    quantized=self.quantized,
                                    _shape=self._shape,
                                    qtype=self.qtype)
            return new_param
        else:
            new_param = ParamsQuant(super().to(device=device,
                                               dtype=dtype,
@@ -224,15 +233,34 @@ class LinearQuant(nn.Linear):
        x0 = self.weight.data

        if x0.device.type == "xpu":
            # GPU logic
            try:
                import intel_extension_for_pytorch
                import linear_q4_0
            except ModuleNotFoundError:
                invalidInputError(False,
                                  "Please `pip install bigdl_core_xe` first.")

            if x_2d.is_contiguous() is False:
                x_2d = x_2d.contiguous()
            # input format of linear_q4.forward is 1: input, 2: weight
            result = linear_q4_0.forward(x_2d, x0)
            new_shape = x_shape[:-1] + (self.out_len,)
            result = result.view(new_shape)
            if self.bias is not None:
                result += self.bias
        else:
            # CPU logic
            # todo may need to set a different number on different platforms
            if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
                x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
                result = F.linear(x, x0_fp32, self.bias)
            else:
                result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
                new_shape = x_shape[:-1] + (self.out_len,)
                result = result.view(new_shape)
                if self.bias is not None:
                    result += self.bias

        return result.to(x.dtype)

View file

@@ -68,6 +68,7 @@ class _BaseAutoModelClass:
        # we can convert the model to quantized later.
        load_in_4bit = kwargs.pop("load_in_4bit", False)
        load_in_low_bit = kwargs.pop("load_in_low_bit", None)
        optimize_model = kwargs.pop("optimize_model", True)

        if load_in_4bit or load_in_low_bit:
            # load int x-bit
@@ -78,7 +79,7 @@ class _BaseAutoModelClass:
            if "pretraining_tp" in config_dict:
                kwargs["pretraining_tp"] = 1
            q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
            model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
        else:
            # load default
            model = cls.HF_Model.from_pretrained(*args, **kwargs)
@@ -86,7 +87,7 @@ class _BaseAutoModelClass:
        return model

    @classmethod
    def load_convert(cls, q_k, optimize_model, *args, **kwargs):
        from .convert import ggml_convert_quant
        invalidInputError(q_k in ggml_tensor_qtype,
                          f"Unknown load_in_low_bit value: {q_k}, expected:"
@@ -94,7 +95,7 @@ class _BaseAutoModelClass:
        qtype = ggml_tensor_qtype[q_k]
        model = cls.HF_Model.from_pretrained(*args, **kwargs)
        model = model.to("cpu")
        model = ggml_convert_quant(model, qtype, optimize_model)
        model.config.update({"bigdl_transformers_low_bit": q_k})

        # add save_low_bit to pretrained model dynamically
@@ -128,6 +129,9 @@ class _BaseAutoModelClass:
        # set default torch_dtype='auto'
        kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
        # set default optimize_model=True
        optimize_model = kwargs.pop("optimize_model", True)

        qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
        # Note that the int4 linear layers cannot currently
        # be recorded in huggingface Pretrained Model or AutoConfig,
@@ -154,7 +158,7 @@ class _BaseAutoModelClass:
        # We forcefully modify the model's definition
        # and the tensor shape of int4 weights without quantization.
        model = ggml_convert_quant(model, qtype, optimize_model, convert_shape_only=True)
        # Load the quantized model at last.
        resolved_archive_file, is_sharded = extract_local_archive_file(
            pretrained_model_name_or_path,
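From the user-facing side, `optimize_model` simply flows from `from_pretrained` through `load_convert` into `ggml_convert_quant`. A short sketch of both settings (the model id is illustrative):
```python
from bigdl.llm.transformers import AutoModelForCausalLM

# default: INT4 quantization plus BigDL-LLM's model-specific optimizations
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_4bit=True)

# the GPU example above passes optimize_model=False to disable those optimizations
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_4bit=True,
                                             optimize_model=False,
                                             trust_remote_code=True)
```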

View file

@@ -83,6 +83,7 @@ def llama_attention_forward_4_31(
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device

    if self.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
@@ -153,8 +154,10 @@ def llama_attention_forward_4_31(
    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
                                                                     dtype=hidden_states.dtype)
    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                         dtype=hidden_states.dtype)

    attn_weights = torch.matmul(query_states,
                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)