Add awq load support (#9453)
* Support directly loading GPTQ models from huggingface * fix style * fix tests * change example structure * address comments * fix style * init * address comments * add examples * fix style * fix style * fix style * fix style * update * remove * meet comments * fix style --------- Co-authored-by: Yang Wang <yang3.wang@intel.com>
This commit is contained in:
		
							parent
							
								
									d2c064124a
								
							
						
					
					
						commit
						d5263e6681
					
				
					 11 changed files with 1090 additions and 19 deletions
				
			
		| 
						 | 
				
			
			@ -0,0 +1,71 @@
 | 
			
		|||
# AWQ
 | 
			
		||||
This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel CPU. For illustration purposes, we utilize the ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ) as a reference.
 | 
			
		||||
 | 
			
		||||
## 0. Requirements
 | 
			
		||||
To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
 | 
			
		||||
 | 
			
		||||
## Example: Predict Tokens using `generate()` API
 | 
			
		||||
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
 | 
			
		||||
### 1. Install
 | 
			
		||||
We suggest using conda to manage environment:
 | 
			
		||||
```bash
 | 
			
		||||
conda create -n llm python=3.9
 | 
			
		||||
conda activate llm
 | 
			
		||||
 | 
			
		||||
pip install autoawq==0.1.6 --no-deps
 | 
			
		||||
pip install bigdl-llm[all] # install bigdl-llm with 'all' option
 | 
			
		||||
pip install transformers==4.35.0
 | 
			
		||||
pip install accelerate==0.24.1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 2. Run
 | 
			
		||||
```
 | 
			
		||||
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Arguments info:
 | 
			
		||||
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2-awq model (e.g. `TheBloke/Llama-2-7B-Chat-AWQ`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'TheBloke/Llama-2-7B-Chat-AWQ'`.
 | 
			
		||||
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
 | 
			
		||||
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
 | 
			
		||||
 | 
			
		||||
> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB memory for further inference.
 | 
			
		||||
>
 | 
			
		||||
> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
 | 
			
		||||
 | 
			
		||||
#### 2.1 Client
 | 
			
		||||
On client Windows machine, it is recommended to run directly with full utilization of all cores:
 | 
			
		||||
```powershell
 | 
			
		||||
python ./generate.py 
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 2.2 Server
 | 
			
		||||
For optimal performance on server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
 | 
			
		||||
 | 
			
		||||
E.g. on Linux,
 | 
			
		||||
```bash
 | 
			
		||||
# set BigDL-Nano env variables
 | 
			
		||||
source bigdl-llm-init
 | 
			
		||||
 | 
			
		||||
# e.g. for a server with 48 cores per socket
 | 
			
		||||
export OMP_NUM_THREADS=48
 | 
			
		||||
numactl -C 0-47 -m 0 python ./generate.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
#### 2.3 Sample Output
 | 
			
		||||
#### ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
 | 
			
		||||
```log
 | 
			
		||||
Inference time: xxxx s
 | 
			
		||||
-------------------- Prompt --------------------
 | 
			
		||||
### HUMAN:
 | 
			
		||||
What is AI?
 | 
			
		||||
 | 
			
		||||
### RESPONSE:
 | 
			
		||||
 | 
			
		||||
-------------------- Output --------------------
 | 
			
		||||
### HUMAN:
 | 
			
		||||
What is AI?
 | 
			
		||||
 | 
			
		||||
### RESPONSE:
 | 
			
		||||
 | 
			
		||||
Artificial intelligence (AI) is the ability of machines to perform tasks that typically require human intelligence, such as learning, problem-solving, decision
 | 
			
		||||
```
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,71 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import time
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
from transformers import LlamaTokenizer
 | 
			
		||||
 | 
			
		||||
# you could tune the prompt based on your own model,
 | 
			
		||||
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
 | 
			
		||||
LLAMA2_PROMPT_FORMAT = """### HUMAN:
 | 
			
		||||
{prompt}
 | 
			
		||||
 | 
			
		||||
### RESPONSE:
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
			
		||||
    parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-Chat-AWQ",
 | 
			
		||||
                        help='The huggingface repo id'
 | 
			
		||||
                             ', or the path to the huggingface checkpoint folder')
 | 
			
		||||
    parser.add_argument('--prompt', type=str, default="What is AI?",
 | 
			
		||||
                        help='Prompt to infer')
 | 
			
		||||
    parser.add_argument('--n-predict', type=int, default=32,
 | 
			
		||||
                        help='Max tokens to predict')
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    model_path = args.repo_id_or_model_path
 | 
			
		||||
 | 
			
		||||
    # Load model in 4 bit,
 | 
			
		||||
    # which convert the relevant layers in the model into INT4 format
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(model_path,
 | 
			
		||||
                                                 load_in_4bit=True,
 | 
			
		||||
                                                 trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    # Load tokenizer
 | 
			
		||||
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
    
 | 
			
		||||
    # Generate predicted tokens
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
 | 
			
		||||
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
 | 
			
		||||
        st = time.time()
 | 
			
		||||
        # if your selected model is capable of utilizing previous key/value attentions
 | 
			
		||||
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
 | 
			
		||||
        # it is important to set `use_cache=True` explicitly in the `generate` function
 | 
			
		||||
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict)
 | 
			
		||||
        end = time.time()
 | 
			
		||||
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
			
		||||
        print(f'Inference time: {end-st} s')
 | 
			
		||||
        print('-'*20, 'Prompt', '-'*20)
 | 
			
		||||
        print(prompt)
 | 
			
		||||
        print('-'*20, 'Output', '-'*20)
 | 
			
		||||
        print(output_str)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,65 @@
 | 
			
		|||
# AWQ
 | 
			
		||||
This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel GPU. For illustration purposes, we utilize the ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ) as a reference.
 | 
			
		||||
 | 
			
		||||
## 0. Requirements
 | 
			
		||||
To run these examples with BigDL-LLM, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
 | 
			
		||||
 | 
			
		||||
## Example: Predict Tokens using `generate()` API
 | 
			
		||||
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using `generate()` API, with BigDL-LLM INT4 optimizations.
 | 
			
		||||
### 1. Install
 | 
			
		||||
We suggest using conda to manage environment:
 | 
			
		||||
```bash
 | 
			
		||||
conda create -n llm python=3.9
 | 
			
		||||
conda activate llm
 | 
			
		||||
 | 
			
		||||
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
 | 
			
		||||
pip install transformers==4.35.0
 | 
			
		||||
pip install autoawq==0.1.6 --no-deps
 | 
			
		||||
pip install accelerate==0.24.1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 2. Configures OneAPI environment variables
 | 
			
		||||
```bash
 | 
			
		||||
source /opt/intel/oneapi/setvars.sh
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 3. Run
 | 
			
		||||
 | 
			
		||||
For optimal performance on Arc, it is recommended to set several environment variables.
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
export USE_XETLA=OFF
 | 
			
		||||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Arguments info:
 | 
			
		||||
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2-awq model (e.g. `TheBloke/Llama-2-7B-Chat-AWQ`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'TheBloke/Llama-2-7B-Chat-AWQ'`.
 | 
			
		||||
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
 | 
			
		||||
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
 | 
			
		||||
 | 
			
		||||
> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will requires approximately 2*X* GB of memory for loading, and ~0.5*X* GB memory for further inference.
 | 
			
		||||
>
 | 
			
		||||
> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
 | 
			
		||||
 | 
			
		||||
#### 2.3 Sample Output
 | 
			
		||||
#### ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
 | 
			
		||||
```log
 | 
			
		||||
Inference time: xxxx s
 | 
			
		||||
-------------------- Prompt --------------------
 | 
			
		||||
### HUMAN:
 | 
			
		||||
What is AI?
 | 
			
		||||
 | 
			
		||||
### RESPONSE:
 | 
			
		||||
 | 
			
		||||
-------------------- Output --------------------
 | 
			
		||||
### HUMAN:
 | 
			
		||||
What is AI?
 | 
			
		||||
 | 
			
		||||
### RESPONSE:
 | 
			
		||||
 | 
			
		||||
Artificial intelligence (AI) is the ability of machines to perform tasks that typically require human intelligence, such as learning, problem-solving, decision
 | 
			
		||||
```
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,71 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import time
 | 
			
		||||
import argparse
 | 
			
		||||
import intel_extension_for_pytorch as ipex
 | 
			
		||||
from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
from transformers import LlamaTokenizer
 | 
			
		||||
 | 
			
		||||
# you could tune the prompt based on your own model,
 | 
			
		||||
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
 | 
			
		||||
LLAMA2_PROMPT_FORMAT = """### HUMAN:
 | 
			
		||||
{prompt}
 | 
			
		||||
 | 
			
		||||
### RESPONSE:
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
			
		||||
    parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-Chat-AWQ",
 | 
			
		||||
                        help='The huggingface repo id'
 | 
			
		||||
                             ', or the path to the huggingface checkpoint folder')
 | 
			
		||||
    parser.add_argument('--prompt', type=str, default="What is AI?",
 | 
			
		||||
                        help='Prompt to infer')
 | 
			
		||||
    parser.add_argument('--n-predict', type=int, default=32,
 | 
			
		||||
                        help='Max tokens to predict')
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    model_path = args.repo_id_or_model_path
 | 
			
		||||
 | 
			
		||||
    # Load model in 4 bit,
 | 
			
		||||
    # which convert the relevant layers in the model into INT4 format
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(model_path,
 | 
			
		||||
                                                 load_in_4bit=True,
 | 
			
		||||
                                                 trust_remote_code=True,).to("xpu")
 | 
			
		||||
 | 
			
		||||
    # Load tokenizer
 | 
			
		||||
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
    
 | 
			
		||||
    # Generate predicted tokens
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
 | 
			
		||||
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to("xpu")
 | 
			
		||||
        st = time.time()
 | 
			
		||||
        # if your selected model is capable of utilizing previous key/value attentions
 | 
			
		||||
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
 | 
			
		||||
        # it is important to set `use_cache=True` explicitly in the `generate` function
 | 
			
		||||
        # to obtain optimal performance with BigDL-LLM INT4 optimizations
 | 
			
		||||
        output = model.generate(input_ids,
 | 
			
		||||
                                max_new_tokens=args.n_predict)
 | 
			
		||||
        end = time.time()
 | 
			
		||||
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
			
		||||
        print(f'Inference time: {end-st} s')
 | 
			
		||||
        print('-'*20, 'Prompt', '-'*20)
 | 
			
		||||
        print(prompt)
 | 
			
		||||
        print('-'*20, 'Output', '-'*20)
 | 
			
		||||
        print(output_str)
 | 
			
		||||
							
								
								
									
										21
									
								
								python/llm/src/bigdl/llm/transformers/awq/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								python/llm/src/bigdl/llm/transformers/awq/__init__.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,21 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
# This would makes sure Python is aware there is more than one sub-package within bigdl,
 | 
			
		||||
# physically located elsewhere.
 | 
			
		||||
# Otherwise there would be module not found error in non-pip's setting as Python would
 | 
			
		||||
# only search the first bigdl package and end up finding only one sub-package.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										54
									
								
								python/llm/src/bigdl/llm/transformers/awq/act.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								python/llm/src/bigdl/llm/transformers/awq/act.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,54 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# ===========================================================================
 | 
			
		||||
#
 | 
			
		||||
# This file is copied from
 | 
			
		||||
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/act.py
 | 
			
		||||
#
 | 
			
		||||
# MIT License
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) 2023 MIT HAN Lab
 | 
			
		||||
#
 | 
			
		||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
# of this software and associated documentation files (the "Software"), to deal
 | 
			
		||||
# in the Software without restriction, including without limitation the rights
 | 
			
		||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
			
		||||
# copies of the Software, and to permit persons to whom the Software is
 | 
			
		||||
# furnished to do so, subject to the following conditions:
 | 
			
		||||
#
 | 
			
		||||
# The above copyright notice and this permission notice shall be included in all
 | 
			
		||||
# copies or substantial portions of the Software.
 | 
			
		||||
#
 | 
			
		||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
			
		||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
			
		||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
			
		||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
			
		||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
			
		||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
			
		||||
# SOFTWARE.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScaledActivation(nn.Module):
 | 
			
		||||
    def __init__(self, module, scales):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.act = module
 | 
			
		||||
        self.scales = nn.Parameter(scales.data)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
 | 
			
		||||
							
								
								
									
										223
									
								
								python/llm/src/bigdl/llm/transformers/awq/awq.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										223
									
								
								python/llm/src/bigdl/llm/transformers/awq/awq.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,223 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# ===========================================================================
 | 
			
		||||
#
 | 
			
		||||
# This file is adapted from
 | 
			
		||||
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147
 | 
			
		||||
#  and https://github.com/mit-han-lab/llm-awq/blob/main/awq/quantize/quantizer.py
 | 
			
		||||
#
 | 
			
		||||
# MIT License
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) 2023 MIT HAN Lab
 | 
			
		||||
#
 | 
			
		||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
# of this software and associated documentation files (the "Software"), to deal
 | 
			
		||||
# in the Software without restriction, including without limitation the rights
 | 
			
		||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
			
		||||
# copies of the Software, and to permit persons to whom the Software is
 | 
			
		||||
# furnished to do so, subject to the following conditions:
 | 
			
		||||
#
 | 
			
		||||
# The above copyright notice and this permission notice shall be included in all
 | 
			
		||||
# copies or substantial portions of the Software.
 | 
			
		||||
#
 | 
			
		||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
			
		||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
			
		||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
			
		||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
			
		||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
			
		||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
			
		||||
# SOFTWARE.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import gc
 | 
			
		||||
import os
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
 | 
			
		||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
 | 
			
		||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
 | 
			
		||||
from transformers.models.bloom.modeling_bloom import BloomBlock
 | 
			
		||||
from transformers import AwqConfig, AutoConfig
 | 
			
		||||
from bigdl.llm.transformers.awq.linear import WQLinear_GEMM, WQLinear_GEMV
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
layer_type_dict = {
 | 
			
		||||
    "mpt": "MPTBlock",
 | 
			
		||||
    "llama": "LlamaDecoderLayer",
 | 
			
		||||
    "opt": "OPTDecoderLayer",
 | 
			
		||||
    "RefinedWeb": "FalconDecoderLayer",
 | 
			
		||||
    "RefinedWebModel": "FalconDecoderLayer",
 | 
			
		||||
    "falcon": "FalconDecoderLayer",
 | 
			
		||||
    "bloom": "BloomBlock",
 | 
			
		||||
    "gptj": "GPTJBlock",
 | 
			
		||||
    "gpt_bigcode": "GPTBigCodeBlock",
 | 
			
		||||
    "mistral": "MistralDecoderLayer",
 | 
			
		||||
    "gpt_neox": "GPTNeoXDecoderLayer",
 | 
			
		||||
    "aquila": "AquilaDecoderLayer",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_op_by_name(layer, name, new_module):
 | 
			
		||||
    levels = name.split('.')
 | 
			
		||||
    if len(levels) > 1:
 | 
			
		||||
        mod_ = layer
 | 
			
		||||
        for l_idx in range(len(levels)-1):
 | 
			
		||||
            if levels[l_idx].isdigit():
 | 
			
		||||
                mod_ = mod_[int(levels[l_idx])]
 | 
			
		||||
            else:
 | 
			
		||||
                mod_ = getattr(mod_, levels[l_idx])
 | 
			
		||||
        setattr(mod_, levels[-1], new_module)
 | 
			
		||||
    else:
 | 
			
		||||
        setattr(layer, name, new_module)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _load_config(model_path, model_filename, safetensors=False,
 | 
			
		||||
                 trust_remote_code=True, max_new_tokens=4096):
 | 
			
		||||
    # [STEP 1] Download model if path is not a directory
 | 
			
		||||
    if not os.path.isdir(model_path):
 | 
			
		||||
        ignore_patterns = ["*msgpack*", "*h5*"]
 | 
			
		||||
        if safetensors:
 | 
			
		||||
            ignore_patterns.extend(["*.pt*", "*.bin*"])
 | 
			
		||||
        else:
 | 
			
		||||
            ignore_patterns.append("*.safetensors*")
 | 
			
		||||
 | 
			
		||||
        model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
 | 
			
		||||
 | 
			
		||||
    if model_filename != '':
 | 
			
		||||
        model_weights_path = model_path + f'/{model_filename}'
 | 
			
		||||
    else:
 | 
			
		||||
        model_weights_path = model_path
 | 
			
		||||
 | 
			
		||||
    # Load model config and set max generation length
 | 
			
		||||
    max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
 | 
			
		||||
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
 | 
			
		||||
    config.max_new_tokens = max_new_tokens
 | 
			
		||||
 | 
			
		||||
    return model_weights_path, config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_named_linears(module):
 | 
			
		||||
    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_blocks(model):
 | 
			
		||||
    if model.__class__.__name__ == 'LlamaForCausalLM':
 | 
			
		||||
        layers = model.model.layers
 | 
			
		||||
    elif isinstance(model, OPTForCausalLM):
 | 
			
		||||
        layers = model.model.decoder.layers
 | 
			
		||||
    elif isinstance(model, BloomForCausalLM):
 | 
			
		||||
        layers = model.transformer.h
 | 
			
		||||
    elif "mpt" in str(model.__class__).lower():
 | 
			
		||||
        layers = model.transformer.blocks
 | 
			
		||||
    elif "falcon" in str(model.__class__).lower():
 | 
			
		||||
        layers = model.transformer.h
 | 
			
		||||
    elif "bigcode" in str(model.__class__).lower():
 | 
			
		||||
        layers = model.transformer.h
 | 
			
		||||
    elif "neox" in str(model.__class__).lower():
 | 
			
		||||
        layers = model.gpt_neox.layers
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False, f"Model type {type(model)} isn't supported.")
 | 
			
		||||
    return layers
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_layer_type(config):
 | 
			
		||||
    if config.model_type not in layer_type_dict.keys():
 | 
			
		||||
        invalidInputError(False, f"{config.model_type} isn't supported yet.")
 | 
			
		||||
    return layer_type_dict[config.model_type]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def scale_activations(module):
 | 
			
		||||
    from bigdl.llm.transformers.awq.act import ScaledActivation
 | 
			
		||||
    param = next(module.parameters())
 | 
			
		||||
    dtype = param.dtype
 | 
			
		||||
    device = param.device
 | 
			
		||||
    if isinstance(module, BloomBlock):
 | 
			
		||||
        if isinstance(module.mlp.gelu_impl, ScaledActivation):
 | 
			
		||||
            return
 | 
			
		||||
        c = module.mlp.dense_h_to_4h.out_features
 | 
			
		||||
        act = ScaledActivation(
 | 
			
		||||
            module.mlp.gelu_impl,
 | 
			
		||||
            torch.ones(c, dtype=dtype, device=device)
 | 
			
		||||
        )
 | 
			
		||||
        set_op_by_name(module, "mlp.gelu_impl", act)
 | 
			
		||||
    elif 'mptblock' in str(module.__class__.__name__).lower():
 | 
			
		||||
        if isinstance(module.ffn.act, ScaledActivation):
 | 
			
		||||
            return
 | 
			
		||||
        c = module.ffn.up_proj.out_features
 | 
			
		||||
        act = ScaledActivation(
 | 
			
		||||
            module.ffn.act,
 | 
			
		||||
            torch.ones(c, dtype=dtype, device=device)
 | 
			
		||||
        )
 | 
			
		||||
        set_op_by_name(module, "ffn.act", act)
 | 
			
		||||
    elif 'falcon' in str(module.__class__).lower():
 | 
			
		||||
        if isinstance(module.mlp.act, ScaledActivation):
 | 
			
		||||
            return
 | 
			
		||||
        c = module.mlp.dense_h_to_4h.out_features
 | 
			
		||||
        act = ScaledActivation(
 | 
			
		||||
            module.mlp.act,
 | 
			
		||||
            torch.ones(c, dtype=dtype, device=device)
 | 
			
		||||
        )
 | 
			
		||||
        set_op_by_name(module, "mlp.act", act)
 | 
			
		||||
    elif 'bigcode' in str(module.__class__).lower():
 | 
			
		||||
        if isinstance(module.mlp.act, ScaledActivation):
 | 
			
		||||
            return
 | 
			
		||||
        c = module.mlp.c_proj.out_features
 | 
			
		||||
        act = ScaledActivation(
 | 
			
		||||
            module.mlp.act,
 | 
			
		||||
            torch.ones(c, dtype=dtype, device=device)
 | 
			
		||||
        )
 | 
			
		||||
        set_op_by_name(module, "mlp.act", act)
 | 
			
		||||
    elif 'neox' in str(module.__class__).lower():
 | 
			
		||||
        if isinstance(module.mlp.act, ScaledActivation):
 | 
			
		||||
            return
 | 
			
		||||
        c = module.mlp.dense_h_to_4h.out_features
 | 
			
		||||
        act = ScaledActivation(
 | 
			
		||||
            module.mlp.act,
 | 
			
		||||
            torch.ones(c, dtype=dtype, device=device)
 | 
			
		||||
        )
 | 
			
		||||
        set_op_by_name(module, "mlp.act", act)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _replace_with_awq_layers(model, awq_config: AwqConfig):
 | 
			
		||||
    layers = get_blocks(model)
 | 
			
		||||
 | 
			
		||||
    for i in tqdm(range(len(layers)), desc="Replacing layers..."):
 | 
			
		||||
        layer = layers[i]
 | 
			
		||||
 | 
			
		||||
        # Get every linear layer in a block
 | 
			
		||||
        named_linears = get_named_linears(layer)
 | 
			
		||||
 | 
			
		||||
        # Replace activation functions
 | 
			
		||||
        scale_activations(layer)
 | 
			
		||||
 | 
			
		||||
        # Replace nn.Linear with WQLinear
 | 
			
		||||
        for name, module in named_linears.items():
 | 
			
		||||
            if awq_config.version == 'gemm':
 | 
			
		||||
                q_linear_module = WQLinear_GEMM
 | 
			
		||||
            elif awq_config.version == 'gemv':
 | 
			
		||||
                q_linear_module = WQLinear_GEMV
 | 
			
		||||
 | 
			
		||||
            q_linear = q_linear_module.from_linear(module,
 | 
			
		||||
                                                   awq_config.bits,
 | 
			
		||||
                                                   awq_config.group_size,
 | 
			
		||||
                                                   True)
 | 
			
		||||
            q_linear.to(next(layer.parameters()).device)
 | 
			
		||||
            set_op_by_name(layer, name, q_linear)
 | 
			
		||||
 | 
			
		||||
        gc.collect()
 | 
			
		||||
							
								
								
									
										99
									
								
								python/llm/src/bigdl/llm/transformers/awq/awq_config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								python/llm/src/bigdl/llm/transformers/awq/awq_config.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,99 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# ===========================================================================
 | 
			
		||||
#
 | 
			
		||||
# This file is copied from
 | 
			
		||||
# https://github.com/huggingface/transformers
 | 
			
		||||
#
 | 
			
		||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
from transformers.utils.quantization_config import QuantizationConfigMixin
 | 
			
		||||
from transformers.utils.quantization_config import AwqBackendPackingMethod,\
 | 
			
		||||
    AWQLinearVersion, QuantizationMethod
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class AwqConfig(QuantizationConfigMixin):
 | 
			
		||||
    """
 | 
			
		||||
    This is a wrapper class about all possible attributes and features that you can
 | 
			
		||||
     play with a model that has been loaded using `auto-awq` library awq quantization
 | 
			
		||||
     relying on auto_awq backend.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        bits (`int`, *optional*, defaults to 4):
 | 
			
		||||
            The number of bits to quantize to.
 | 
			
		||||
        group_size (`int`, *optional*, defaults to 128):
 | 
			
		||||
            The group size to use for quantization.
 | 
			
		||||
            Recommended value is 128 and -1 uses per-column quantization.
 | 
			
		||||
        zero_point (`bool`, *optional*, defaults to `True`):
 | 
			
		||||
            Whether to use zero point quantization.
 | 
			
		||||
        version (`AWQLinearVersion`, *optional*, defaults to
 | 
			
		||||
        `AWQLinearVersion.GEMM`):
 | 
			
		||||
            The version of the quantization algorithm to use.
 | 
			
		||||
            GEMM is better for big batch_size (e.g. >= 8) otherwise,
 | 
			
		||||
            GEMV is better (e.g. < 8 )
 | 
			
		||||
        backend (`AwqBackendPackingMethod`, *optional*, defaults to
 | 
			
		||||
        `AwqBackendPackingMethod.AUTOAWQ`):
 | 
			
		||||
            The quantization backend. Some models might be quantized using `llm-awq` backend.
 | 
			
		||||
            This is useful for users that quantize their own models using `llm-awq` library.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        bits: int = 4,
 | 
			
		||||
        group_size: int = 128,
 | 
			
		||||
        zero_point: bool = True,
 | 
			
		||||
        version: AWQLinearVersion = AWQLinearVersion.GEMM,
 | 
			
		||||
        backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        self.quant_method = QuantizationMethod.AWQ
 | 
			
		||||
 | 
			
		||||
        self.bits = bits
 | 
			
		||||
        self.group_size = group_size
 | 
			
		||||
        self.zero_point = zero_point
 | 
			
		||||
        self.version = version
 | 
			
		||||
        self.backend = backend
 | 
			
		||||
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def post_init(self):
 | 
			
		||||
        r"""
 | 
			
		||||
        Safety checker that arguments are correct
 | 
			
		||||
        """
 | 
			
		||||
        invalidInputError(self.backend == AwqBackendPackingMethod.AUTOAWQ,
 | 
			
		||||
                          "Only supported quantization backends in "
 | 
			
		||||
                          f"{AwqBackendPackingMethod.AUTOAWQ} - "
 | 
			
		||||
                          f"not recognized backend {self.backend}")
 | 
			
		||||
 | 
			
		||||
        invalidInputError(self.version in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV],
 | 
			
		||||
                          "Only supported versions are in [AWQLinearVersion.GEMM,"
 | 
			
		||||
                          f"AWQLinearVersion.GEMV] - not recognized version {self.version}")
 | 
			
		||||
							
								
								
									
										284
									
								
								python/llm/src/bigdl/llm/transformers/awq/linear.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										284
									
								
								python/llm/src/bigdl/llm/transformers/awq/linear.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,284 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# ===========================================================================
 | 
			
		||||
#
 | 
			
		||||
# This file is adapted from
 | 
			
		||||
# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/linear.py
 | 
			
		||||
#
 | 
			
		||||
# MIT License
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) 2023 MIT HAN Lab
 | 
			
		||||
#
 | 
			
		||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
# of this software and associated documentation files (the "Software"), to deal
 | 
			
		||||
# in the Software without restriction, including without limitation the rights
 | 
			
		||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
			
		||||
# copies of the Software, and to permit persons to whom the Software is
 | 
			
		||||
# furnished to do so, subject to the following conditions:
 | 
			
		||||
#
 | 
			
		||||
# The above copyright notice and this permission notice shall be included in all
 | 
			
		||||
# copies or substantial portions of the Software.
 | 
			
		||||
#
 | 
			
		||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
			
		||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
			
		||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
			
		||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
			
		||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
			
		||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
			
		||||
# SOFTWARE.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from bigdl.llm.utils.common import invalidOperationError, invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_divisible(c, divisor):
 | 
			
		||||
    return (c + divisor - 1) // divisor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
 | 
			
		||||
    if group_size >= 128:
 | 
			
		||||
        size_multiplier = 1
 | 
			
		||||
    elif group_size == 64:
 | 
			
		||||
        size_multiplier = 2
 | 
			
		||||
    elif group_size == 32:
 | 
			
		||||
        size_multiplier = 4
 | 
			
		||||
    else:
 | 
			
		||||
        invalidOperationError(False,
 | 
			
		||||
                              f"Not implemented group size {group_size}.")
 | 
			
		||||
 | 
			
		||||
    base_width = make_divisible(in_features // group_size, pack_num)
 | 
			
		||||
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
 | 
			
		||||
    return base_width
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WQLinear_GEMM(nn.Module):
 | 
			
		||||
    def __init__(self, bits, group_size, in_features, out_features, bias, dev):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        invalidOperationError(bits == 4, "Only 4-bit are supported for now.")
 | 
			
		||||
 | 
			
		||||
        self.in_features = in_features
 | 
			
		||||
        self.out_features = out_features
 | 
			
		||||
        self.bits = bits
 | 
			
		||||
        self.group_size = group_size if group_size != -1 else in_features
 | 
			
		||||
 | 
			
		||||
        self.wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
 | 
			
		||||
                                dtype=torch.int32) * self.bits).unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
        # quick sanity check (make sure aligment)
 | 
			
		||||
        invalidInputError(self.in_features % self.group_size == 0,
 | 
			
		||||
                          f"Invalid in_features number {self.in_features}.")
 | 
			
		||||
        invalidInputError(out_features % (32 // self.bits) == 0,
 | 
			
		||||
                          f"Invalid out_features number {out_features}.")
 | 
			
		||||
 | 
			
		||||
        self.register_buffer('qweight',
 | 
			
		||||
                             torch.zeros((in_features,
 | 
			
		||||
                                          out_features // (32 // self.bits)),
 | 
			
		||||
                                         dtype=torch.int32, device=dev))
 | 
			
		||||
        self.register_buffer('qzeros',
 | 
			
		||||
                             torch.zeros((in_features // self.group_size,
 | 
			
		||||
                                          out_features // (32 // self.bits)),
 | 
			
		||||
                                         dtype=torch.int32, device=dev))
 | 
			
		||||
        self.register_buffer('scales',
 | 
			
		||||
                             torch.zeros((in_features // self.group_size, out_features),
 | 
			
		||||
                                         dtype=torch.float16, device=dev))
 | 
			
		||||
        if bias:
 | 
			
		||||
            self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16,
 | 
			
		||||
                                                     device=dev))
 | 
			
		||||
        else:
 | 
			
		||||
            self.bias = None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
 | 
			
		||||
        awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
 | 
			
		||||
                         linear.bias is not None, linear.weight.device)
 | 
			
		||||
        if init_only:  # just prepare for loading sd
 | 
			
		||||
            return awq_linear
 | 
			
		||||
 | 
			
		||||
        # need scales and zeros info for real quantization
 | 
			
		||||
        invalidInputError(scales is not None and zeros is not None,
 | 
			
		||||
                          "Scales and zeros should not be None.")
 | 
			
		||||
        scale_zeros = zeros * scales
 | 
			
		||||
 | 
			
		||||
        awq_linear.scales = scales.clone().half()
 | 
			
		||||
        if linear.bias is not None:
 | 
			
		||||
            awq_linear.bias = linear.bias.clone().half()
 | 
			
		||||
 | 
			
		||||
        pack_num = 32 // awq_linear.bits
 | 
			
		||||
 | 
			
		||||
        intweight = []
 | 
			
		||||
        for idx in range(awq_linear.in_features):
 | 
			
		||||
            intweight.append(
 | 
			
		||||
                torch.round((linear.weight.data[:, idx] +
 | 
			
		||||
                             scale_zeros[idx // group_size]) /
 | 
			
		||||
                            awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
 | 
			
		||||
        intweight = torch.cat(intweight, dim=1)
 | 
			
		||||
        intweight = intweight.t().contiguous()
 | 
			
		||||
        intweight = intweight.to(dtype=torch.int32)
 | 
			
		||||
        qweight = torch.zeros((intweight.shape[0],
 | 
			
		||||
                               intweight.shape[1] // (32 // awq_linear.bits)),
 | 
			
		||||
                              dtype=torch.int32, device=intweight.device)
 | 
			
		||||
 | 
			
		||||
        torch.set_printoptions(threshold=10_000)
 | 
			
		||||
        print(intweight)
 | 
			
		||||
 | 
			
		||||
        for col in range(intweight.shape[1] // pack_num):
 | 
			
		||||
            if awq_linear.bits == 4:
 | 
			
		||||
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
 | 
			
		||||
            else:
 | 
			
		||||
                invalidOperationError(False, "Only 4-bit are supported for now.")
 | 
			
		||||
            for i in range(pack_num):
 | 
			
		||||
                qweight_col = intweight[:, col * pack_num + order_map[i]]
 | 
			
		||||
                qweight[:, col] |= qweight_col << (i * awq_linear.bits)
 | 
			
		||||
        awq_linear.qweight = qweight
 | 
			
		||||
 | 
			
		||||
        zeros = zeros.to(dtype=torch.int32)
 | 
			
		||||
        qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // (32 // awq_linear.bits)),
 | 
			
		||||
                             dtype=torch.int32, device=zeros.device)
 | 
			
		||||
 | 
			
		||||
        for col in range(zeros.shape[1] // pack_num):
 | 
			
		||||
            if awq_linear.bits == 4:
 | 
			
		||||
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
 | 
			
		||||
            else:
 | 
			
		||||
                invalidOperationError(False, "Only 4-bit are supported for now.")
 | 
			
		||||
            for i in range(pack_num):
 | 
			
		||||
                qzero_col = zeros[:, col * pack_num + order_map[i]]
 | 
			
		||||
                qzeros[:, col] |= qzero_col << (i * awq_linear.bits)
 | 
			
		||||
        awq_linear.qzeros = qzeros
 | 
			
		||||
 | 
			
		||||
        return awq_linear
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        invalidOperationError(False, "Bigdl-llm does not support inference awq models directly.")
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return 'in_features={}, out_features={}, bias={}, bits={}, group_size={}'.format(
 | 
			
		||||
            self.in_features, self.out_features, self.bias is not None, self.bits, self.group_size
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WQLinear_GEMV(nn.Module):
 | 
			
		||||
    def __init__(self, bits, group_size, in_features, out_features, bias, dev):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        invalidOperationError(bits == 4, "Only 4-bit are supported for now.")
 | 
			
		||||
 | 
			
		||||
        self.in_features = in_features
 | 
			
		||||
        self.out_features = out_features
 | 
			
		||||
        self.bits = bits
 | 
			
		||||
        self.group_size = group_size if group_size != -1 else in_features
 | 
			
		||||
        self.split_k_iters = 8
 | 
			
		||||
 | 
			
		||||
        # quick sanity check (make sure aligment)
 | 
			
		||||
        invalidInputError(self.in_features % self.group_size == 0,
 | 
			
		||||
                          f"Invalid in_features number {self.in_features}.")
 | 
			
		||||
        invalidInputError(out_features % (32 // self.bits) == 0,
 | 
			
		||||
                          f"Invalid out_features number {out_features}.")
 | 
			
		||||
        pack_num = (32 // self.bits)
 | 
			
		||||
 | 
			
		||||
        self.register_buffer('qweight',
 | 
			
		||||
                             torch.zeros((out_features, in_features // pack_num),
 | 
			
		||||
                                         dtype=torch.int32, device=dev))
 | 
			
		||||
        self.register_buffer('qzeros',
 | 
			
		||||
                             torch.zeros((out_features,
 | 
			
		||||
                                          calculate_zeros_width(in_features,
 | 
			
		||||
                                                                self.group_size)),
 | 
			
		||||
                                         dtype=torch.int32, device=dev))
 | 
			
		||||
        self.register_buffer('scales',
 | 
			
		||||
                             torch.zeros((out_features,
 | 
			
		||||
                                          calculate_zeros_width(in_features, self.group_size)
 | 
			
		||||
                                          * pack_num), dtype=torch.float16, device=dev))
 | 
			
		||||
        if bias:
 | 
			
		||||
            self.register_buffer('bias', torch.zeros((out_features),
 | 
			
		||||
                                                     dtype=torch.float16, device=dev))
 | 
			
		||||
        else:
 | 
			
		||||
            self.bias = None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None):
 | 
			
		||||
        awq_linear = cls(bits, group_size, linear.in_features, linear.out_features,
 | 
			
		||||
                         linear.bias is not None, linear.weight.device)
 | 
			
		||||
        if init_only:  # just prepare for loading sd
 | 
			
		||||
            return awq_linear
 | 
			
		||||
 | 
			
		||||
        # need scales and zeros info for real quantization
 | 
			
		||||
        invalidInputError(scales is not None and zeros is not None,
 | 
			
		||||
                          "Scales and zeros should not be None.")
 | 
			
		||||
        scale_zeros = zeros * scales
 | 
			
		||||
 | 
			
		||||
        pack_num = 32 // awq_linear.bits
 | 
			
		||||
        qscales = torch.zeros(
 | 
			
		||||
            (scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num),
 | 
			
		||||
            dtype=torch.float16,
 | 
			
		||||
            device=scales.device
 | 
			
		||||
        )
 | 
			
		||||
        qscales[:, :scales.shape[1]] = scales
 | 
			
		||||
        awq_linear.scales = qscales
 | 
			
		||||
        if linear.bias is not None:
 | 
			
		||||
            awq_linear.bias = linear.bias.clone().half()
 | 
			
		||||
 | 
			
		||||
        intweight = []
 | 
			
		||||
        for idx in range(awq_linear.in_features):
 | 
			
		||||
            intweight.append(
 | 
			
		||||
                torch.round((linear.weight.data[:, idx] +
 | 
			
		||||
                             scale_zeros[:, idx // group_size]) /
 | 
			
		||||
                            awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None])
 | 
			
		||||
        intweight = torch.cat(intweight, dim=1)
 | 
			
		||||
        intweight = intweight.to(dtype=torch.int32)
 | 
			
		||||
        qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.bits),
 | 
			
		||||
                              dtype=torch.int32, device=intweight.device)
 | 
			
		||||
 | 
			
		||||
        for col in range(intweight.shape[1] // pack_num):
 | 
			
		||||
            if awq_linear.bits == 4:
 | 
			
		||||
                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
 | 
			
		||||
            else:
 | 
			
		||||
                invalidOperationError(False, "Only 4-bit are supported for now.")
 | 
			
		||||
            for i in range(pack_num):
 | 
			
		||||
                qweight_col = intweight[:, col * pack_num + order_map[i]]
 | 
			
		||||
                qweight[:, col] |= qweight_col << (i * awq_linear.bits)
 | 
			
		||||
        awq_linear.qweight = qweight
 | 
			
		||||
 | 
			
		||||
        zeros = zeros.to(dtype=torch.int32)
 | 
			
		||||
        qzeros = torch.zeros(
 | 
			
		||||
            (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
 | 
			
		||||
            dtype=torch.int32,
 | 
			
		||||
            device=zeros.device,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
 | 
			
		||||
            if awq_linear.bits == 4:
 | 
			
		||||
                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
 | 
			
		||||
            else:
 | 
			
		||||
                invalidOperationError(False, "Only 4-bit are supported for now.")
 | 
			
		||||
            for i in range(pack_num):
 | 
			
		||||
                if col * pack_num + order_map[i] >= zeros.shape[1]:
 | 
			
		||||
                    continue
 | 
			
		||||
                qzero_col = zeros[:, col * pack_num + order_map[i]]
 | 
			
		||||
                qzeros[:, col] |= qzero_col << (i * awq_linear.bits)
 | 
			
		||||
        awq_linear.qzeros = qzeros
 | 
			
		||||
        return awq_linear
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        invalidOperationError(False, "Bigdl-llm does not support inference awq models directly.")
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self) -> str:
 | 
			
		||||
        return 'in_features={}, out_features={}, bias={}, bits={}, group_size={}'.format(
 | 
			
		||||
            self.in_features, self.out_features, self.bias is not None, self.bits, self.group_size
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			@ -53,6 +53,10 @@ def is_auto_gptq_available():
 | 
			
		|||
    return importlib.util.find_spec("auto_gptq") is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_auto_awq_available():
 | 
			
		||||
    return importlib.util.find_spec("awq") is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_deepspeed_available():
 | 
			
		||||
    return importlib.util.find_spec("deepspeed") is not None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -61,18 +65,24 @@ if is_auto_gptq_available():
 | 
			
		|||
    from auto_gptq.utils.peft_utils import QuantLinearCuda, QuantLinearCudaOld
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_auto_awq_available():
 | 
			
		||||
    from bigdl.llm.transformers.awq.linear import WQLinear_GEMM
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_linear_module(module):
 | 
			
		||||
 | 
			
		||||
    in_features = None
 | 
			
		||||
    out_features = None
 | 
			
		||||
    mp_group = None
 | 
			
		||||
 | 
			
		||||
    is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
 | 
			
		||||
 | 
			
		||||
    if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
 | 
			
		||||
        in_features = module.infeatures
 | 
			
		||||
        out_features = module.outfeatures
 | 
			
		||||
        mp_group = None
 | 
			
		||||
        result = True
 | 
			
		||||
    elif isinstance(module, nn.Linear):
 | 
			
		||||
    elif isinstance(module, nn.Linear) or is_awq:
 | 
			
		||||
        in_features = module.in_features
 | 
			
		||||
        out_features = module.out_features
 | 
			
		||||
        mp_group = None
 | 
			
		||||
| 
						 | 
				
			
			@ -102,8 +112,7 @@ from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size
 | 
			
		|||
Q4_1 = get_ggml_qk_size("asym_int4")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_gptq(module):
 | 
			
		||||
 | 
			
		||||
def convert_gptq(module, awq=False):
 | 
			
		||||
    scales = module.scales
 | 
			
		||||
 | 
			
		||||
    zeros = torch.bitwise_right_shift(
 | 
			
		||||
| 
						 | 
				
			
			@ -111,14 +120,22 @@ def convert_gptq(module):
 | 
			
		|||
        module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
 | 
			
		||||
    zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1)
 | 
			
		||||
 | 
			
		||||
    zeros = zeros + 1
 | 
			
		||||
    if not awq:
 | 
			
		||||
        zeros = zeros + 1
 | 
			
		||||
    zeros = zeros.reshape(scales.shape)
 | 
			
		||||
 | 
			
		||||
    weight = torch.bitwise_right_shift(
 | 
			
		||||
        torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1),
 | 
			
		||||
        module.wf.unsqueeze(-1)).to(torch.int8)
 | 
			
		||||
    weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
 | 
			
		||||
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
 | 
			
		||||
    if awq:
 | 
			
		||||
        weight = torch.bitwise_right_shift(
 | 
			
		||||
            torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // module.bits),
 | 
			
		||||
            module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
 | 
			
		||||
        weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
 | 
			
		||||
        weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
 | 
			
		||||
    else:
 | 
			
		||||
        weight = torch.bitwise_right_shift(
 | 
			
		||||
            torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1),
 | 
			
		||||
            module.wf.unsqueeze(-1)).to(torch.int8)
 | 
			
		||||
        weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
 | 
			
		||||
        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
 | 
			
		||||
 | 
			
		||||
    # convert weight to ggml format
 | 
			
		||||
    weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1])
 | 
			
		||||
| 
						 | 
				
			
			@ -171,7 +188,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                in_features, out_features, mp_group = linear_args
 | 
			
		||||
                with init_empty_weights():
 | 
			
		||||
                    new_linear = None
 | 
			
		||||
                    if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
 | 
			
		||||
                    is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
 | 
			
		||||
                    is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
 | 
			
		||||
                    if is_gptq or is_awq:
 | 
			
		||||
                        has_bias = module.bias is not None and module.bias.abs().sum() != 0
 | 
			
		||||
                        new_linear = LowBitLinear(
 | 
			
		||||
                            in_features,
 | 
			
		||||
| 
						 | 
				
			
			@ -184,7 +203,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                        invalidInputError(device_type != "meta",
 | 
			
		||||
                                          "converting from meta device is not supported")
 | 
			
		||||
                        # Copy the weights
 | 
			
		||||
                        paramsLowBit = FP4Params(data=convert_gptq(module),
 | 
			
		||||
                        paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
 | 
			
		||||
                                                 requires_grad=False,
 | 
			
		||||
                                                 quantized=True,
 | 
			
		||||
                                                 _shape=(out_features, in_features),
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,6 +13,29 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
#
 | 
			
		||||
# MIT License
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) 2023 MIT HAN Lab
 | 
			
		||||
#
 | 
			
		||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
# of this software and associated documentation files (the "Software"), to deal
 | 
			
		||||
# in the Software without restriction, including without limitation the rights
 | 
			
		||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
			
		||||
# copies of the Software, and to permit persons to whom the Software is
 | 
			
		||||
# furnished to do so, subject to the following conditions:
 | 
			
		||||
#
 | 
			
		||||
# The above copyright notice and this permission notice shall be included in all
 | 
			
		||||
# copies or substantial portions of the Software.
 | 
			
		||||
#
 | 
			
		||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
			
		||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
			
		||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
			
		||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
			
		||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
			
		||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
			
		||||
# SOFTWARE.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import transformers
 | 
			
		||||
from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
| 
						 | 
				
			
			@ -123,6 +146,29 @@ class _BaseAutoModelClass:
 | 
			
		|||
                        from transformers import GPTQConfig
 | 
			
		||||
                        user_quantization_config = GPTQConfig(bits=4, use_exllama=False)
 | 
			
		||||
                    kwargs["quantization_config"] = user_quantization_config
 | 
			
		||||
                elif q_config["quant_method"] == "awq":
 | 
			
		||||
                    from bigdl.llm.transformers.awq.awq_config import AwqConfig
 | 
			
		||||
                    awq_config = AwqConfig.from_dict(q_config)
 | 
			
		||||
                    invalidInputError(awq_config.bits == 4,
 | 
			
		||||
                                      "Only 4-bit awq is supported in bigdl-llm.")
 | 
			
		||||
                    invalidInputError(awq_config.version == "gemm",
 | 
			
		||||
                                      "Only gemm version is supported in bigdl-llm.")
 | 
			
		||||
                    invalidInputError(awq_config.backend == "autoawq",
 | 
			
		||||
                                      "Only autoawq backend is supported in bigdl-llm.")
 | 
			
		||||
                    invalidInputError(awq_config.zero_point,
 | 
			
		||||
                                      "Only awq zero_point = True is supported in bigdl-llm.")
 | 
			
		||||
                    if load_in_low_bit is not None:
 | 
			
		||||
                        invalidInputError(load_in_low_bit == "asym_int4",
 | 
			
		||||
                                          "You can only load awq model as aysm_int4 low bit type.")
 | 
			
		||||
 | 
			
		||||
                    load_in_low_bit = "asym_int4"
 | 
			
		||||
 | 
			
		||||
                    if int(awq_config.group_size) % get_ggml_qk_size(load_in_low_bit) != 0:
 | 
			
		||||
                        invalidInputError(False,
 | 
			
		||||
                                          (f"group_size must be divisible by "
 | 
			
		||||
                                           f"{get_ggml_qk_size(load_in_low_bit)}."))
 | 
			
		||||
 | 
			
		||||
                    kwargs["quantization_config"] = awq_config
 | 
			
		||||
 | 
			
		||||
            # load int x-bit
 | 
			
		||||
            kwargs["low_cpu_mem_usage"] = True
 | 
			
		||||
| 
						 | 
				
			
			@ -156,16 +202,63 @@ class _BaseAutoModelClass:
 | 
			
		|||
        # and lead to args missing.
 | 
			
		||||
        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
 | 
			
		||||
        replace_embedding = kwargs.pop("replace_embedding", False)
 | 
			
		||||
        quant_config = kwargs.pop("quantization_config", None)
 | 
			
		||||
        _args = copy.deepcopy(args)
 | 
			
		||||
        _kwargs = copy.deepcopy(kwargs)
 | 
			
		||||
        try:
 | 
			
		||||
            model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
			
		||||
        except NotImplementedError:
 | 
			
		||||
            logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
 | 
			
		||||
                        "will fall to traditional load method with higher memory consumption.")
 | 
			
		||||
            _kwargs["low_cpu_mem_usage"] = False
 | 
			
		||||
            model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
 | 
			
		||||
            model.config.update({"bigdl_lcmu_enabled": False})
 | 
			
		||||
        awq_config = None
 | 
			
		||||
 | 
			
		||||
        if quant_config and quant_config.quant_method == "awq":
 | 
			
		||||
            # The latest transformers only support cuda version
 | 
			
		||||
            # This load awq ckpt logic is copied from
 | 
			
		||||
            # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147
 | 
			
		||||
            from accelerate import init_empty_weights, infer_auto_device_map,\
 | 
			
		||||
                load_checkpoint_in_model
 | 
			
		||||
            from bigdl.llm.transformers.awq.awq import _replace_with_awq_layers,\
 | 
			
		||||
                get_layer_type, _load_config
 | 
			
		||||
            awq_config = quant_config
 | 
			
		||||
            model_weights_path, config = _load_config(args[0], '', max_new_tokens=None,
 | 
			
		||||
                                                      safetensors=True)
 | 
			
		||||
            with init_empty_weights():
 | 
			
		||||
                model = cls.HF_Model.from_config(config=config, trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
            _replace_with_awq_layers(model, awq_config=awq_config)
 | 
			
		||||
 | 
			
		||||
            model.tie_weights()
 | 
			
		||||
 | 
			
		||||
            # Get device map
 | 
			
		||||
            device_map = infer_auto_device_map(
 | 
			
		||||
                model,
 | 
			
		||||
                no_split_module_classes=[get_layer_type(config)],
 | 
			
		||||
                max_memory=None,
 | 
			
		||||
                dtype=config.torch_dtype
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Load checkpoint
 | 
			
		||||
            load_checkpoint_in_model(
 | 
			
		||||
                model,
 | 
			
		||||
                checkpoint=model_weights_path,
 | 
			
		||||
                device_map=device_map,
 | 
			
		||||
                offload_folder=None,
 | 
			
		||||
                dtype=config.torch_dtype
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Offloading dispatch
 | 
			
		||||
            from accelerate import dispatch_model
 | 
			
		||||
            model = dispatch_model(
 | 
			
		||||
                model,
 | 
			
		||||
                device_map=device_map,
 | 
			
		||||
                offload_dir=None
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            try:
 | 
			
		||||
                model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
			
		||||
            except NotImplementedError:
 | 
			
		||||
                logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
 | 
			
		||||
                            "will fall to traditional load method with higher memory consumption.")
 | 
			
		||||
                _kwargs["low_cpu_mem_usage"] = False
 | 
			
		||||
                model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
 | 
			
		||||
                model.config.update({"bigdl_lcmu_enabled": False})
 | 
			
		||||
 | 
			
		||||
        model = model.to("cpu")
 | 
			
		||||
        model = ggml_convert_low_bit(model, qtype, optimize_model,
 | 
			
		||||
                                     modules_to_not_convert=modules_to_not_convert,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue