Add awq load support (#9453)
* Support directly loading GPTQ models from huggingface * fix style * fix tests * change example structure * address comments * fix style * init * address comments * add examples * fix style * fix style * fix style * fix style * update * remove * meet comments * fix style --------- Co-authored-by: Yang Wang <yang3.wang@intel.com>
This commit is contained in:
parent
d2c064124a
commit
d5263e6681
11 changed files with 1090 additions and 19 deletions
|
|
@ -0,0 +1,71 @@
|
|||
# AWQ
|
||||
This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel CPU. For illustration purposes, we utilize the ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ) as a reference.
|
||||
|
||||
## 0. Requirements
|
||||
To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
|
||||
|
||||
## Example: Predict Tokens using `generate()` API
|
||||
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
|
||||
### 1. Install
|
||||
We suggest using conda to manage environment:
|
||||
```bash
|
||||
conda create -n llm python=3.9
|
||||
conda activate llm
|
||||
|
||||
pip install autoawq==0.1.6 --no-deps
|
||||
pip install bigdl-llm[all] # install bigdl-llm with 'all' option
|
||||
pip install transformers==4.35.0
|
||||
pip install accelerate==0.24.1
|
||||
```
|
||||
|
||||
### 2. Run
|
||||
```
|
||||
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
|
||||
```
|
||||
|
||||
Arguments info:
|
||||
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2-awq model (e.g. `TheBloke/Llama-2-7B-Chat-AWQ`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'TheBloke/Llama-2-7B-Chat-AWQ'`.
|
||||
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
|
||||
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
|
||||
|
||||
> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB memory for further inference.
|
||||
>
|
||||
> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
|
||||
|
||||
#### 2.1 Client
|
||||
On client Windows machine, it is recommended to run directly with full utilization of all cores:
|
||||
```powershell
|
||||
python ./generate.py
|
||||
```
|
||||
|
||||
#### 2.2 Server
|
||||
For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
|
||||
|
||||
E.g. on Linux,
|
||||
```bash
|
||||
# set BigDL-Nano env variables
|
||||
source bigdl-llm-init
|
||||
|
||||
# e.g. for a server with 48 cores per socket
|
||||
export OMP_NUM_THREADS=48
|
||||
numactl -C 0-47 -m 0 python ./generate.py
|
||||
```
|
||||
|
||||
#### 2.3 Sample Output
|
||||
#### ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
|
||||
```log
|
||||
Inference time: xxxx s
|
||||
-------------------- Prompt --------------------
|
||||
### HUMAN:
|
||||
What is AI?
|
||||
|
||||
### RESPONSE:
|
||||
|
||||
-------------------- Output --------------------
|
||||
### HUMAN:
|
||||
What is AI?
|
||||
|
||||
### RESPONSE:
|
||||
|
||||
Artificial intelligence (AI) is the ability of machines to perform tasks that typically require human intelligence, such as learning, problem-solving, decision
|
||||
```
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
import time
|
||||
import argparse
|
||||
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
# you could tune the prompt based on your own model,
|
||||
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
|
||||
LLAMA2_PROMPT_FORMAT = """### HUMAN:
|
||||
{prompt}
|
||||
|
||||
### RESPONSE:
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
|
||||
parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-Chat-AWQ",
|
||||
help='The huggingface repo id'
|
||||
', or the path to the huggingface checkpoint folder')
|
||||
parser.add_argument('--prompt', type=str, default="What is AI?",
|
||||
help='Prompt to infer')
|
||||
parser.add_argument('--n-predict', type=int, default=32,
|
||||
help='Max tokens to predict')
|
||||
|
||||
args = parser.parse_args()
|
||||
model_path = args.repo_id_or_model_path
|
||||
|
||||
# Load model in 4 bit,
|
||||
# which convert the relevant layers in the model into INT4 format
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||
load_in_4bit=True,
|
||||
trust_remote_code=True)
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
# Generate predicted tokens
|
||||
with torch.inference_mode():
|
||||
prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
st = time.time()
|
||||
# if your selected model is capable of utilizing previous key/value attentions
|
||||
# to enhance decoding speed, but has `"use_cache": false` in its model config,
|
||||
# it is important to set `use_cache=True` explicitly in the `generate` function
|
||||
# to obtain optimal performance with BigDL-LLM INT4 optimizations
|
||||
output = model.generate(input_ids,
|
||||
max_new_tokens=args.n_predict)
|
||||
end = time.time()
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
print(f'Inference time: {end-st} s')
|
||||
print('-'*20, 'Prompt', '-'*20)
|
||||
print(prompt)
|
||||
print('-'*20, 'Output', '-'*20)
|
||||
print(output_str)
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
# AWQ
|
||||
This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel GPU. For illustration purposes, we utilize the ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ) as a reference.
|
||||
|
||||
## 0. Requirements
|
||||
To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
|
||||
|
||||
## Example: Predict Tokens using `generate()` API
|
||||
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
|
||||
### 1. Install
|
||||
We suggest using conda to manage environment:
|
||||
```bash
|
||||
conda create -n llm python=3.9
|
||||
conda activate llm
|
||||
|
||||
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
|
||||
pip install transformers==4.35.0
|
||||
pip install autoawq==0.1.6 --no-deps
|
||||
pip install accelerate==0.24.1
|
||||
```
|
||||
|
||||
### 2. Configures OneAPI environment variables
|
||||
```bash
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
```
|
||||
|
||||
### 3. Run
|
||||
|
||||
For optimal performance on Arc, it is recommended to set several environment variables.
|
||||
|
||||
```bash
|
||||
export USE_XETLA=OFF
|
||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
|
||||
```
|
||||
|
||||
```
|
||||
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
|
||||
```
|
||||
|
||||
Arguments info:
|
||||
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2-awq model (e.g. `TheBloke/Llama-2-7B-Chat-AWQ`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'TheBloke/Llama-2-7B-Chat-AWQ'`.
|
||||
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
|
||||
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
|
||||
|
||||
> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB memory for further inference.
|
||||
>
|
||||
> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
|
||||
|
||||
#### 2.3 Sample Output
|
||||
#### ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
|
||||
```log
|
||||
Inference time: xxxx s
|
||||
-------------------- Prompt --------------------
|
||||
### HUMAN:
|
||||
What is AI?
|
||||
|
||||
### RESPONSE:
|
||||
|
||||
-------------------- Output --------------------
|
||||
### HUMAN:
|
||||
What is AI?
|
||||
|
||||
### RESPONSE:
|
||||
|
||||
Artificial intelligence (AI) is the ability of machines to perform tasks that typically require human intelligence, such as learning, problem-solving, decision
|
||||
```
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
import time
|
||||
import argparse
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
# you could tune the prompt based on your own model,
|
||||
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
|
||||
LLAMA2_PROMPT_FORMAT = """### HUMAN:
|
||||
{prompt}
|
||||
|
||||
### RESPONSE:
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
|
||||
parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-Chat-AWQ",
|
||||
help='The huggingface repo id'
|
||||
', or the path to the huggingface checkpoint folder')
|
||||
parser.add_argument('--prompt', type=str, default="What is AI?",
|
||||
help='Prompt to infer')
|
||||
parser.add_argument('--n-predict', type=int, default=32,
|
||||
help='Max tokens to predict')
|
||||
|
||||
args = parser.parse_args()
|
||||
model_path = args.repo_id_or_model_path
|
||||
|
||||
# Load model in 4 bit,
|
||||
# which convert the relevant layers in the model into INT4 format
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||
load_in_4bit=True,
|
||||
trust_remote_code=True,).to("xpu")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
# Generate predicted tokens
|
||||
with torch.inference_mode():
|
||||
prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("xpu")
|
||||
st = time.time()
|
||||
# if your selected model is capable of utilizing previous key/value attentions
|
||||
# to enhance decoding speed, but has `"use_cache": false` in its model config,
|
||||
# it is important to set `use_cache=True` explicitly in the `generate` function
|
||||
# to obtain optimal performance with BigDL-LLM INT4 optimizations
|
||||
output = model.generate(input_ids,
|
||||
max_new_tokens=args.n_predict)
|
||||
end = time.time()
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
print(f'Inference time: {end-st} s')
|
||||
print('-'*20, 'Prompt', '-'*20)
|
||||
print(prompt)
|
||||
print('-'*20, 'Output', '-'*20)
|
||||
print(output_str)
|
||||
21
python/llm/src/bigdl/llm/transformers/awq/__init__.py
Normal file
21
python/llm/src/bigdl/llm/transformers/awq/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# This would makes sure Python is aware there is more than one sub-package within bigdl,
|
||||
# physically located elsewhere.
|
||||
# Otherwise there would be module not found error in non-pip's setting as Python would
|
||||
# only search the first bigdl package and end up finding only one sub-package.
|
||||
|
||||
54
python/llm/src/bigdl/llm/transformers/awq/act.py
Normal file
54
python/llm/src/bigdl/llm/transformers/awq/act.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is copied from
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/act.py
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
def __init__(self, module, scales):
|
||||
super().__init__()
|
||||
self.act = module
|
||||
self.scales = nn.Parameter(scales.data)
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
|
||||
223
python/llm/src/bigdl/llm/transformers/awq/awq.py
Normal file
223
python/llm/src/bigdl/llm/transformers/awq/awq.py
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147
|
||||
# and https://github.com/mit-han-lab/llm-awq/blob/main/awq/quantize/quantizer.py
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
|
||||
import gc
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.bloom.modeling_bloom import BloomBlock
|
||||
from transformers import AwqConfig, AutoConfig
|
||||
from bigdl.llm.transformers.awq.linear import WQLinear_GEMM, WQLinear_GEMV
|
||||
from huggingface_hub import snapshot_download
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
layer_type_dict = {
|
||||
"mpt": "MPTBlock",
|
||||
"llama": "LlamaDecoderLayer",
|
||||
"opt": "OPTDecoderLayer",
|
||||
"RefinedWeb": "FalconDecoderLayer",
|
||||
"RefinedWebModel": "FalconDecoderLayer",
|
||||
"falcon": "FalconDecoderLayer",
|
||||
"bloom": "BloomBlock",
|
||||
"gptj": "GPTJBlock",
|
||||
"gpt_bigcode": "GPTBigCodeBlock",
|
||||
"mistral": "MistralDecoderLayer",
|
||||
"gpt_neox": "GPTNeoXDecoderLayer",
|
||||
"aquila": "AquilaDecoderLayer",
|
||||
}
|
||||
|
||||
|
||||
def set_op_by_name(layer, name, new_module):
|
||||
levels = name.split('.')
|
||||
if len(levels) > 1:
|
||||
mod_ = layer
|
||||
for l_idx in range(len(levels)-1):
|
||||
if levels[l_idx].isdigit():
|
||||
mod_ = mod_[int(levels[l_idx])]
|
||||
else:
|
||||
mod_ = getattr(mod_, levels[l_idx])
|
||||
setattr(mod_, levels[-1], new_module)
|
||||
else:
|
||||
setattr(layer, name, new_module)
|
||||
|
||||
|
||||
def _load_config(model_path, model_filename, safetensors=False,
|
||||
trust_remote_code=True, max_new_tokens=4096):
|
||||
# [STEP 1] Download model if path is not a directory
|
||||
if not os.path.isdir(model_path):
|
||||
ignore_patterns = ["*msgpack*", "*h5*"]
|
||||
if safetensors:
|
||||
ignore_patterns.extend(["*.pt*", "*.bin*"])
|
||||
else:
|
||||
ignore_patterns.append("*.safetensors*")
|
||||
|
||||
model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
|
||||
|
||||
if model_filename != '':
|
||||
model_weights_path = model_path + f'/{model_filename}'
|
||||
else:
|
||||
model_weights_path = model_path
|
||||
|
||||
# Load model config and set max generation length
|
||||
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
|
||||
config.max_new_tokens = max_new_tokens
|
||||
|
||||
return model_weights_path, config
|
||||
|
||||
|
||||
def get_named_linears(module):
|
||||
return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
|
||||
|
||||
|
||||
def get_blocks(model):
|
||||
if model.__class__.__name__ == 'LlamaForCausalLM':
|
||||
layers = model.model.layers
|
||||
elif isinstance(model, OPTForCausalLM):
|
||||
layers = model.model.decoder.layers
|
||||
elif isinstance(model, BloomForCausalLM):
|
||||
layers = model.transformer.h
|
||||
elif "mpt" in str(model.__class__).lower():
|
||||
layers = model.transformer.blocks
|
||||
elif "falcon" in str(model.__class__).lower():
|
||||
layers = model.transformer.h
|
||||
elif "bigcode" in str(model.__class__).lower():
|
||||
layers = model.transformer.h
|
||||
elif "neox" in str(model.__class__).lower():
|
||||
layers = model.gpt_neox.layers
|
||||
else:
|
||||
invalidInputError(False, f"Model type {type(model)} isn't supported.")
|
||||
return layers
|
||||
|
||||
|
||||
def get_layer_type(config):
|
||||
if config.model_type not in layer_type_dict.keys():
|
||||
invalidInputError(False, f"{config.model_type} isn't supported yet.")
|
||||
return layer_type_dict[config.model_type]
|
||||
|
||||
|
||||
def scale_activations(module):
|
||||
from bigdl.llm.transformers.awq.act import ScaledActivation
|
||||
param = next(module.parameters())
|
||||
dtype = param.dtype
|
||||
device = param.device
|
||||
if isinstance(module, BloomBlock):
|
||||
if isinstance(module.mlp.gelu_impl, ScaledActivation):
|
||||
return
|
||||
c = module.mlp.dense_h_to_4h.out_features
|
||||
act = ScaledActivation(
|
||||
module.mlp.gelu_impl,
|
||||
torch.ones(c, dtype=dtype, device=device)
|
||||
)
|
||||
set_op_by_name(module, "mlp.gelu_impl", act)
|
||||
elif 'mptblock' in str(module.__class__.__name__).lower():
|
||||
if isinstance(module.ffn.act, ScaledActivation):
|
||||
return
|
||||
c = module.ffn.up_proj.out_features
|
||||
act = ScaledActivation(
|
||||
module.ffn.act,
|
||||
torch.ones(c, dtype=dtype, device=device)
|
||||
)
|
||||
set_op_by_name(module, "ffn.act", act)
|
||||
elif 'falcon' in str(module.__class__).lower():
|
||||
if isinstance(module.mlp.act, ScaledActivation):
|
||||
return
|
||||
c = module.mlp.dense_h_to_4h.out_features
|
||||
act = ScaledActivation(
|
||||
module.mlp.act,
|
||||
torch.ones(c, dtype=dtype, device=device)
|
||||
)
|
||||
set_op_by_name(module, "mlp.act", act)
|
||||
elif 'bigcode' in str(module.__class__).lower():
|
||||
if isinstance(module.mlp.act, ScaledActivation):
|
||||
return
|
||||
c = module.mlp.c_proj.out_features
|
||||
act = ScaledActivation(
|
||||
module.mlp.act,
|
||||
torch.ones(c, dtype=dtype, device=device)
|
||||
)
|
||||
set_op_by_name(module, "mlp.act", act)
|
||||
elif 'neox' in str(module.__class__).lower():
|
||||
if isinstance(module.mlp.act, ScaledActivation):
|
||||
return
|
||||
c = module.mlp.dense_h_to_4h.out_features
|
||||
act = ScaledActivation(
|
||||
module.mlp.act,
|
||||
torch.ones(c, dtype=dtype, device=device)
|
||||
)
|
||||
set_op_by_name(module, "mlp.act", act)
|
||||
|
||||
|
||||
def _replace_with_awq_layers(model, awq_config: AwqConfig):
|
||||
layers = get_blocks(model)
|
||||
|
||||
for i in tqdm(range(len(layers)), desc="Replacing layers..."):
|
||||
layer = layers[i]
|
||||
|
||||
# Get every linear layer in a block
|
||||
named_linears = get_named_linears(layer)
|
||||
|
||||
# Replace activation functions
|
||||
scale_activations(layer)
|
||||
|
||||
# Replace nn.Linear with WQLinear
|
||||
for name, module in named_linears.items():
|
||||
if awq_config.version == 'gemm':
|
||||
q_linear_module = WQLinear_GEMM
|
||||
elif awq_config.version == 'gemv':
|
||||
q_linear_module = WQLinear_GEMV
|
||||
|
||||
q_linear = q_linear_module.from_linear(module,
|
||||
awq_config.bits,
|
||||
awq_config.group_size,
|
||||
True)
|
||||
q_linear.to(next(layer.parameters()).device)
|
||||
set_op_by_name(layer, name, q_linear)
|
||||
|
||||
gc.collect()
|
||||
99
python/llm/src/bigdl/llm/transformers/awq/awq_config.py
Normal file
99
python/llm/src/bigdl/llm/transformers/awq/awq_config.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is copied from
|
||||
# https://github.com/huggingface/transformers
|
||||
#
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from transformers.utils.quantization_config import QuantizationConfigMixin
|
||||
from transformers.utils.quantization_config import AwqBackendPackingMethod,\
|
||||
AWQLinearVersion, QuantizationMethod
|
||||
|
||||
|
||||
@dataclass
|
||||
class AwqConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
This is a wrapper class about all possible attributes and features that you can
|
||||
play with a model that has been loaded using `auto-awq` library awq quantization
|
||||
relying on auto_awq backend.
|
||||
|
||||
Args:
|
||||
bits (`int`, *optional*, defaults to 4):
|
||||
The number of bits to quantize to.
|
||||
group_size (`int`, *optional*, defaults to 128):
|
||||
The group size to use for quantization.
|
||||
Recommended value is 128 and -1 uses per-column quantization.
|
||||
zero_point (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use zero point quantization.
|
||||
version (`AWQLinearVersion`, *optional*, defaults to
|
||||
`AWQLinearVersion.GEMM`):
|
||||
The version of the quantization algorithm to use.
|
||||
GEMM is better for big batch_size (e.g. >= 8) otherwise,
|
||||
GEMV is better (e.g. < 8 )
|
||||
backend (`AwqBackendPackingMethod`, *optional*, defaults to
|
||||
`AwqBackendPackingMethod.AUTOAWQ`):
|
||||
The quantization backend. Some models might be quantized using `llm-awq` backend.
|
||||
This is useful for users that quantize their own models using `llm-awq` library.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bits: int = 4,
|
||||
group_size: int = 128,
|
||||
zero_point: bool = True,
|
||||
version: AWQLinearVersion = AWQLinearVersion.GEMM,
|
||||
backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ,
|
||||
**kwargs,
|
||||
):
|
||||
self.quant_method = QuantizationMethod.AWQ
|
||||
|
||||
self.bits = bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.version = version
|
||||
self.backend = backend
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct
|
||||
"""
|
||||
invalidInputError(self.backend == AwqBackendPackingMethod.AUTOAWQ,
|
||||
"Only supported quantization backends in "
|
||||
f"{AwqBackendPackingMethod.AUTOAWQ} - "
|
||||
f"not recognized backend {self.backend}")
|
||||
|
||||
invalidInputError(self.version in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV],
|
||||
"Only supported versions are in [AWQLinearVersion.GEMM,"
|
||||
f"AWQLinearVersion.GEMV] - not recognized version {self.version}")
|
||||
284
python/llm/src/bigdl/llm/transformers/awq/linear.py
Normal file
284
python/llm/src/bigdl/llm/transformers/awq/linear.py
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/linear.py
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from bigdl.llm.utils.common import invalidOperationError, invalidInputError
|
||||
|
||||
|
||||
def make_divisible(c, divisor):
|
||||
return (c + divisor - 1) // divisor
|
||||
|
||||
|
||||
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
|
||||
if group_size >= 128:
|
||||
size_multiplier = 1
|
||||
elif group_size == 64:
|
||||
size_multiplier = 2
|
||||
elif group_size == 32:
|
||||
size_multiplier = 4
|
||||
else:
|
||||
invalidOperationError(False,
|
||||
f"Not implemented group size {group_size}.")
|
||||
|
||||
base_width = make_divisible(in_features // group_size, pack_num)
|
||||
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
|
||||
return base_width
|
||||
|
||||
|
||||
class WQLinear_GEMM(nn.Module):
|
||||
def __init__(self, bits, group_size, in_features, out_features, bias, dev):
|
||||
super().__init__()
|
||||
|
||||
invalidOperationError(bits == 4, "Only 4-bit are supported for now.")
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else in_features
|
||||
|
||||
self.wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
|
||||
dtype=torch.int32) * self.bits).unsqueeze(0)
|
||||
|
||||
# quick sanity check (make sure aligment)
|
||||
invalidInputError(self.in_features % self.group_size == 0,
|
||||
f"Invalid in_features number {self.in_features}.")
|
||||
invalidInputError(out_features % (32 // self.bits) == 0,
|
||||
f"Invalid out_features number {out_features}.")
|
||||
|
||||
self.register_buffer('qweight',
|
||||
torch.zeros((in_features,
|
||||
out_features // (32 // self.bits)),
|
||||
dtype=torch.int32, device=dev))
|
||||
self.register_buffer('qzeros',
|
||||
torch.zeros((in_features // self.group_size,
|
||||
out_features // (32 // self.bits)),
|
||||
dtype=torch.int32, device=dev))
|
||||
self.register_buffer('scales',
|
||||
torch.zeros((in_features // self.group_size, out_features),
|
||||
dtype=torch.float16, device=dev))
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16,
|
||||
device=dev))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@classmethod
|
||||
def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
|
||||
awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
|
||||
linear.bias is not None, linear.weight.device)
|
||||
if init_only: # just prepare for loading sd
|
||||
return awq_linear
|
||||
|
||||
# need scales and zeros info for real quantization
|
||||
invalidInputError(scales is not None and zeros is not None,
|
||||
"Scales and zeros should not be None.")
|
||||
scale_zeros = zeros * scales
|
||||
|
||||
awq_linear.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
awq_linear.bias = linear.bias.clone().half()
|
||||
|
||||
pack_num = 32 // awq_linear.bits
|
||||
|
||||
intweight = []
|
||||
for idx in range(awq_linear.in_features):
|
||||
intweight.append(
|
||||
torch.round((linear.weight.data[:, idx] +
|
||||
scale_zeros[idx // group_size]) /
|
||||
awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.to(dtype=torch.int32)
|
||||
qweight = torch.zeros((intweight.shape[0],
|
||||
intweight.shape[1] // (32 // awq_linear.bits)),
|
||||
dtype=torch.int32, device=intweight.device)
|
||||
|
||||
torch.set_printoptions(threshold=10_000)
|
||||
print(intweight)
|
||||
|
||||
for col in range(intweight.shape[1] // pack_num):
|
||||
if awq_linear.bits == 4:
|
||||
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
|
||||
else:
|
||||
invalidOperationError(False, "Only 4-bit are supported for now.")
|
||||
for i in range(pack_num):
|
||||
qweight_col = intweight[:, col * pack_num + order_map[i]]
|
||||
qweight[:, col] |= qweight_col << (i * awq_linear.bits)
|
||||
awq_linear.qweight = qweight
|
||||
|
||||
zeros = zeros.to(dtype=torch.int32)
|
||||
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // (32 // awq_linear.bits)),
|
||||
dtype=torch.int32, device=zeros.device)
|
||||
|
||||
for col in range(zeros.shape[1] // pack_num):
|
||||
if awq_linear.bits == 4:
|
||||
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
|
||||
else:
|
||||
invalidOperationError(False, "Only 4-bit are supported for now.")
|
||||
for i in range(pack_num):
|
||||
qzero_col = zeros[:, col * pack_num + order_map[i]]
|
||||
qzeros[:, col] |= qzero_col << (i * awq_linear.bits)
|
||||
awq_linear.qzeros = qzeros
|
||||
|
||||
return awq_linear
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
invalidOperationError(False, "Bigdl-llm does not support inference awq models directly.")
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'in_features={}, out_features={}, bias={}, bits={}, group_size={}'.format(
|
||||
self.in_features, self.out_features, self.bias is not None, self.bits, self.group_size
|
||||
)
|
||||
|
||||
|
||||
class WQLinear_GEMV(nn.Module):
|
||||
def __init__(self, bits, group_size, in_features, out_features, bias, dev):
|
||||
super().__init__()
|
||||
|
||||
invalidOperationError(bits == 4, "Only 4-bit are supported for now.")
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else in_features
|
||||
self.split_k_iters = 8
|
||||
|
||||
# quick sanity check (make sure aligment)
|
||||
invalidInputError(self.in_features % self.group_size == 0,
|
||||
f"Invalid in_features number {self.in_features}.")
|
||||
invalidInputError(out_features % (32 // self.bits) == 0,
|
||||
f"Invalid out_features number {out_features}.")
|
||||
pack_num = (32 // self.bits)
|
||||
|
||||
self.register_buffer('qweight',
|
||||
torch.zeros((out_features, in_features // pack_num),
|
||||
dtype=torch.int32, device=dev))
|
||||
self.register_buffer('qzeros',
|
||||
torch.zeros((out_features,
|
||||
calculate_zeros_width(in_features,
|
||||
self.group_size)),
|
||||
dtype=torch.int32, device=dev))
|
||||
self.register_buffer('scales',
|
||||
torch.zeros((out_features,
|
||||
calculate_zeros_width(in_features, self.group_size)
|
||||
* pack_num), dtype=torch.float16, device=dev))
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((out_features),
|
||||
dtype=torch.float16, device=dev))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@classmethod
|
||||
def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
|
||||
awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
|
||||
linear.bias is not None, linear.weight.device)
|
||||
if init_only: # just prepare for loading sd
|
||||
return awq_linear
|
||||
|
||||
# need scales and zeros info for real quantization
|
||||
invalidInputError(scales is not None and zeros is not None,
|
||||
"Scales and zeros should not be None.")
|
||||
scale_zeros = zeros * scales
|
||||
|
||||
pack_num = 32 // awq_linear.bits
|
||||
qscales = torch.zeros(
|
||||
(scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num),
|
||||
dtype=torch.float16,
|
||||
device=scales.device
|
||||
)
|
||||
qscales[:, :scales.shape[1]] = scales
|
||||
awq_linear.scales = qscales
|
||||
if linear.bias is not None:
|
||||
awq_linear.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(awq_linear.in_features):
|
||||
intweight.append(
|
||||
torch.round((linear.weight.data[:, idx] +
|
||||
scale_zeros[:, idx // group_size]) /
|
||||
awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None])
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.to(dtype=torch.int32)
|
||||
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.bits),
|
||||
dtype=torch.int32, device=intweight.device)
|
||||
|
||||
for col in range(intweight.shape[1] // pack_num):
|
||||
if awq_linear.bits == 4:
|
||||
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
else:
|
||||
invalidOperationError(False, "Only 4-bit are supported for now.")
|
||||
for i in range(pack_num):
|
||||
qweight_col = intweight[:, col * pack_num + order_map[i]]
|
||||
qweight[:, col] |= qweight_col << (i * awq_linear.bits)
|
||||
awq_linear.qweight = qweight
|
||||
|
||||
zeros = zeros.to(dtype=torch.int32)
|
||||
qzeros = torch.zeros(
|
||||
(zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
|
||||
dtype=torch.int32,
|
||||
device=zeros.device,
|
||||
)
|
||||
|
||||
for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
|
||||
if awq_linear.bits == 4:
|
||||
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
else:
|
||||
invalidOperationError(False, "Only 4-bit are supported for now.")
|
||||
for i in range(pack_num):
|
||||
if col * pack_num + order_map[i] >= zeros.shape[1]:
|
||||
continue
|
||||
qzero_col = zeros[:, col * pack_num + order_map[i]]
|
||||
qzeros[:, col] |= qzero_col << (i * awq_linear.bits)
|
||||
awq_linear.qzeros = qzeros
|
||||
return awq_linear
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x):
|
||||
invalidOperationError(False, "Bigdl-llm does not support inference awq models directly.")
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'in_features={}, out_features={}, bias={}, bits={}, group_size={}'.format(
|
||||
self.in_features, self.out_features, self.bias is not None, self.bits, self.group_size
|
||||
)
|
||||
|
|
@ -53,6 +53,10 @@ def is_auto_gptq_available():
|
|||
return importlib.util.find_spec("auto_gptq") is not None
|
||||
|
||||
|
||||
def is_auto_awq_available():
|
||||
return importlib.util.find_spec("awq") is not None
|
||||
|
||||
|
||||
def is_deepspeed_available():
|
||||
return importlib.util.find_spec("deepspeed") is not None
|
||||
|
||||
|
|
@ -61,18 +65,24 @@ if is_auto_gptq_available():
|
|||
from auto_gptq.utils.peft_utils import QuantLinearCuda, QuantLinearCudaOld
|
||||
|
||||
|
||||
if is_auto_awq_available():
|
||||
from bigdl.llm.transformers.awq.linear import WQLinear_GEMM
|
||||
|
||||
|
||||
def is_linear_module(module):
|
||||
|
||||
in_features = None
|
||||
out_features = None
|
||||
mp_group = None
|
||||
|
||||
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
|
||||
|
||||
if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
|
||||
in_features = module.infeatures
|
||||
out_features = module.outfeatures
|
||||
mp_group = None
|
||||
result = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
elif isinstance(module, nn.Linear) or is_awq:
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
mp_group = None
|
||||
|
|
@ -102,8 +112,7 @@ from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size
|
|||
Q4_1 = get_ggml_qk_size("asym_int4")
|
||||
|
||||
|
||||
def convert_gptq(module):
|
||||
|
||||
def convert_gptq(module, awq=False):
|
||||
scales = module.scales
|
||||
|
||||
zeros = torch.bitwise_right_shift(
|
||||
|
|
@ -111,14 +120,22 @@ def convert_gptq(module):
|
|||
module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1)
|
||||
|
||||
zeros = zeros + 1
|
||||
if not awq:
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(scales.shape)
|
||||
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1),
|
||||
module.wf.unsqueeze(-1)).to(torch.int8)
|
||||
weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
if awq:
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // module.bits),
|
||||
module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
|
||||
weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
|
||||
weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
|
||||
else:
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1),
|
||||
module.wf.unsqueeze(-1)).to(torch.int8)
|
||||
weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
|
||||
# convert weight to ggml format
|
||||
weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1])
|
||||
|
|
@ -171,7 +188,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
in_features, out_features, mp_group = linear_args
|
||||
with init_empty_weights():
|
||||
new_linear = None
|
||||
if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
|
||||
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
|
||||
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
|
||||
if is_gptq or is_awq:
|
||||
has_bias = module.bias is not None and module.bias.abs().sum() != 0
|
||||
new_linear = LowBitLinear(
|
||||
in_features,
|
||||
|
|
@ -184,7 +203,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
invalidInputError(device_type != "meta",
|
||||
"converting from meta device is not supported")
|
||||
# Copy the weights
|
||||
paramsLowBit = FP4Params(data=convert_gptq(module),
|
||||
paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
|
||||
requires_grad=False,
|
||||
quantized=True,
|
||||
_shape=(out_features, in_features),
|
||||
|
|
|
|||
|
|
@ -13,6 +13,29 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
|
||||
import transformers
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
|
@ -123,6 +146,29 @@ class _BaseAutoModelClass:
|
|||
from transformers import GPTQConfig
|
||||
user_quantization_config = GPTQConfig(bits=4, use_exllama=False)
|
||||
kwargs["quantization_config"] = user_quantization_config
|
||||
elif q_config["quant_method"] == "awq":
|
||||
from bigdl.llm.transformers.awq.awq_config import AwqConfig
|
||||
awq_config = AwqConfig.from_dict(q_config)
|
||||
invalidInputError(awq_config.bits == 4,
|
||||
"Only 4-bit awq is supported in bigdl-llm.")
|
||||
invalidInputError(awq_config.version == "gemm",
|
||||
"Only gemm version is supported in bigdl-llm.")
|
||||
invalidInputError(awq_config.backend == "autoawq",
|
||||
"Only autoawq backend is supported in bigdl-llm.")
|
||||
invalidInputError(awq_config.zero_point,
|
||||
"Only awq zero_point = True is supported in bigdl-llm.")
|
||||
if load_in_low_bit is not None:
|
||||
invalidInputError(load_in_low_bit == "asym_int4",
|
||||
"You can only load awq model as aysm_int4 low bit type.")
|
||||
|
||||
load_in_low_bit = "asym_int4"
|
||||
|
||||
if int(awq_config.group_size) % get_ggml_qk_size(load_in_low_bit) != 0:
|
||||
invalidInputError(False,
|
||||
(f"group_size must be divisible by "
|
||||
f"{get_ggml_qk_size(load_in_low_bit)}."))
|
||||
|
||||
kwargs["quantization_config"] = awq_config
|
||||
|
||||
# load int x-bit
|
||||
kwargs["low_cpu_mem_usage"] = True
|
||||
|
|
@ -156,16 +202,63 @@ class _BaseAutoModelClass:
|
|||
# and lead to args missing.
|
||||
modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
|
||||
replace_embedding = kwargs.pop("replace_embedding", False)
|
||||
quant_config = kwargs.pop("quantization_config", None)
|
||||
_args = copy.deepcopy(args)
|
||||
_kwargs = copy.deepcopy(kwargs)
|
||||
try:
|
||||
model = cls.HF_Model.from_pretrained(*args, **kwargs)
|
||||
except NotImplementedError:
|
||||
logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
|
||||
"will fall to traditional load method with higher memory consumption.")
|
||||
_kwargs["low_cpu_mem_usage"] = False
|
||||
model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
|
||||
model.config.update({"bigdl_lcmu_enabled": False})
|
||||
awq_config = None
|
||||
|
||||
if quant_config and quant_config.quant_method == "awq":
|
||||
# The latest transformers only support cuda version
|
||||
# This load awq ckpt logic is copied from
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147
|
||||
from accelerate import init_empty_weights, infer_auto_device_map,\
|
||||
load_checkpoint_in_model
|
||||
from bigdl.llm.transformers.awq.awq import _replace_with_awq_layers,\
|
||||
get_layer_type, _load_config
|
||||
awq_config = quant_config
|
||||
model_weights_path, config = _load_config(args[0], '', max_new_tokens=None,
|
||||
safetensors=True)
|
||||
with init_empty_weights():
|
||||
model = cls.HF_Model.from_config(config=config, trust_remote_code=True)
|
||||
|
||||
_replace_with_awq_layers(model, awq_config=awq_config)
|
||||
|
||||
model.tie_weights()
|
||||
|
||||
# Get device map
|
||||
device_map = infer_auto_device_map(
|
||||
model,
|
||||
no_split_module_classes=[get_layer_type(config)],
|
||||
max_memory=None,
|
||||
dtype=config.torch_dtype
|
||||
)
|
||||
|
||||
# Load checkpoint
|
||||
load_checkpoint_in_model(
|
||||
model,
|
||||
checkpoint=model_weights_path,
|
||||
device_map=device_map,
|
||||
offload_folder=None,
|
||||
dtype=config.torch_dtype
|
||||
)
|
||||
|
||||
# Offloading dispatch
|
||||
from accelerate import dispatch_model
|
||||
model = dispatch_model(
|
||||
model,
|
||||
device_map=device_map,
|
||||
offload_dir=None
|
||||
)
|
||||
else:
|
||||
try:
|
||||
model = cls.HF_Model.from_pretrained(*args, **kwargs)
|
||||
except NotImplementedError:
|
||||
logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
|
||||
"will fall to traditional load method with higher memory consumption.")
|
||||
_kwargs["low_cpu_mem_usage"] = False
|
||||
model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
|
||||
model.config.update({"bigdl_lcmu_enabled": False})
|
||||
|
||||
model = model.to("cpu")
|
||||
model = ggml_convert_low_bit(model, qtype, optimize_model,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
|
|
|
|||
Loading…
Reference in a new issue