diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/README.md
new file mode 100644
index 00000000..6474d507
--- /dev/null
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/README.md
@@ -0,0 +1,71 @@
+# AWQ
+This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel CPU. For illustration purposes, we use ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ) as a reference.
+
+## 0. Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.9
+conda activate llm
+
+pip install autoawq==0.1.6 --no-deps
+pip install bigdl-llm[all] # install bigdl-llm with 'all' option
+pip install transformers==4.35.0
+pip install accelerate==0.24.1
+```
+
+### 2. Run
+```
+python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the AWQ Llama2 model (e.g. `TheBloke/Llama-2-7B-Chat-AWQ`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'TheBloke/Llama-2-7B-Chat-AWQ'`.
+- `--prompt PROMPT`: argument defining the prompt for inference (with integrated prompt format for chat). It defaults to `'What is AI?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will require approximately 2*X* GB of memory for loading, and ~0.5*X* GB of memory for further inference.
+>
+> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
+
+#### 2.1 Client
+On a client Windows machine, it is recommended to run directly with full utilization of all cores:
+```powershell
+python ./generate.py
+```
+
+#### 2.2 Server
+For optimal performance on servers, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.
+
+E.g. on Linux,
+```bash
+# set BigDL-Nano env variables
+source bigdl-llm-init
+
+# e.g. for a server with 48 cores per socket
+export OMP_NUM_THREADS=48
+numactl -C 0-47 -m 0 python ./generate.py
+```
+
+#### 2.3 Sample Output
+#### ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+### HUMAN:
+What is AI?
+
+### RESPONSE:
+
+-------------------- Output --------------------
+### HUMAN:
+What is AI?
+ +### RESPONSE: + +Artificial intelligence (AI) is the ability of machines to perform tasks that typically require human intelligence, such as learning, problem-solving, decision +``` diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/generate.py b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/generate.py new file mode 100644 index 00000000..c9e7c066 --- /dev/null +++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/generate.py @@ -0,0 +1,71 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import time +import argparse + +from bigdl.llm.transformers import AutoModelForCausalLM +from transformers import LlamaTokenizer + +# you could tune the prompt based on your own model, +# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style +LLAMA2_PROMPT_FORMAT = """### HUMAN: +{prompt} + +### RESPONSE: +""" + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-Chat-AWQ", + help='The huggingface repo id' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + # Load model in 4 bit, + # which convert the relevant layers in the model into INT4 format + model = AutoModelForCausalLM.from_pretrained(model_path, + load_in_4bit=True, + trust_remote_code=True) + + # Load tokenizer + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Generate predicted tokens + with torch.inference_mode(): + prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt) + input_ids = tokenizer.encode(prompt, return_tensors="pt") + st = time.time() + # if your selected model is capable of utilizing previous key/value attentions + # to enhance decoding speed, but has `"use_cache": false` in its model config, + # it is important to set `use_cache=True` explicitly in the `generate` function + # to obtain optimal performance with BigDL-LLM INT4 optimizations + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + end = time.time() + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/README.md b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/README.md new file mode 100644 index 00000000..f223d5e9 --- /dev/null +++ 
b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/README.md
@@ -0,0 +1,65 @@
+# AWQ
+This example shows how to directly run 4-bit AWQ models using BigDL-LLM on Intel GPU. For illustration purposes, we use ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ) as a reference.
+
+## 0. Requirements
+To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations.
+### 1. Install
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.9
+conda activate llm
+
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+pip install transformers==4.35.0
+pip install autoawq==0.1.6 --no-deps
+pip install accelerate==0.24.1
+```
+
+### 2. Configure OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Run
+
+For optimal performance on Arc, it is recommended to set several environment variables.
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+```
+
+```
+python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the AWQ Llama2 model (e.g. `TheBloke/Llama-2-7B-Chat-AWQ`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'TheBloke/Llama-2-7B-Chat-AWQ'`.
+- `--prompt PROMPT`: argument defining the prompt for inference (with integrated prompt format for chat). It defaults to `'What is AI?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+> **Note**: When loading the model in 4-bit, BigDL-LLM converts linear layers in the model into INT4 format. In theory, a *X*B model saved in 16-bit will require approximately 2*X* GB of memory for loading, and ~0.5*X* GB of memory for further inference.
+>
+> Please select the appropriate size of the Llama2 model based on the capabilities of your machine.
+
+#### 3.1 Sample Output
+#### ["TheBloke/Llama-2-7B-Chat-AWQ"](https://huggingface.co/TheBloke/Llama-2-7B-Chat-AWQ)
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+### HUMAN:
+What is AI?
+
+### RESPONSE:
+
+-------------------- Output --------------------
+### HUMAN:
+What is AI?
+
+### RESPONSE:
+
+Artificial intelligence (AI) is the ability of machines to perform tasks that typically require human intelligence, such as learning, problem-solving, decision
+```
\ No newline at end of file
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/generate.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/generate.py
new file mode 100644
index 00000000..9e9b72df
--- /dev/null
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/AWQ/generate.py
@@ -0,0 +1,71 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import time +import argparse +import intel_extension_for_pytorch as ipex +from bigdl.llm.transformers import AutoModelForCausalLM +from transformers import LlamaTokenizer + +# you could tune the prompt based on your own model, +# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style +LLAMA2_PROMPT_FORMAT = """### HUMAN: +{prompt} + +### RESPONSE: +""" + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-Chat-AWQ", + help='The huggingface repo id' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + # Load model in 4 bit, + # which convert the relevant layers in the model into INT4 format + model = AutoModelForCausalLM.from_pretrained(model_path, + load_in_4bit=True, + trust_remote_code=True,).to("xpu") + + # Load tokenizer + tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Generate predicted tokens + with torch.inference_mode(): + prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to("xpu") + st = time.time() + # if your selected model is capable of utilizing previous key/value attentions + # to enhance decoding speed, but has `"use_cache": false` in its model config, + # it is important to set `use_cache=True` explicitly in the `generate` function + # to obtain optimal performance with BigDL-LLM INT4 optimizations + output = model.generate(input_ids, + max_new_tokens=args.n_predict) + end = time.time() + output_str = tokenizer.decode(output[0], skip_special_tokens=True) + print(f'Inference time: {end-st} s') + print('-'*20, 'Prompt', '-'*20) + print(prompt) + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/src/bigdl/llm/transformers/awq/__init__.py b/python/llm/src/bigdl/llm/transformers/awq/__init__.py new file mode 100644 index 00000000..9db92c50 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/awq/__init__.py @@ -0,0 +1,21 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This would makes sure Python is aware there is more than one sub-package within bigdl, +# physically located elsewhere. 
+# Otherwise there would be module not found error in non-pip's setting as Python would +# only search the first bigdl package and end up finding only one sub-package. + diff --git a/python/llm/src/bigdl/llm/transformers/awq/act.py b/python/llm/src/bigdl/llm/transformers/awq/act.py new file mode 100644 index 00000000..6ab6dad2 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/awq/act.py @@ -0,0 +1,54 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# =========================================================================== +# +# This file is copied from +# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/act.py +# +# MIT License +# +# Copyright (c) 2023 MIT HAN Lab +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +import torch.nn as nn + + +class ScaledActivation(nn.Module): + def __init__(self, module, scales): + super().__init__() + self.act = module + self.scales = nn.Parameter(scales.data) + + def forward(self, x): + return self.act(x) / self.scales.view(1, 1, -1).to(x.device) diff --git a/python/llm/src/bigdl/llm/transformers/awq/awq.py b/python/llm/src/bigdl/llm/transformers/awq/awq.py new file mode 100644 index 00000000..511d9ede --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/awq/awq.py @@ -0,0 +1,223 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# =========================================================================== +# +# This file is adapted from +# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147 +# and https://github.com/mit-han-lab/llm-awq/blob/main/awq/quantize/quantizer.py +# +# MIT License +# +# Copyright (c) 2023 MIT HAN Lab +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +import gc +import os +from tqdm import tqdm +import torch +import torch.nn as nn +from transformers.models.bloom.modeling_bloom import BloomForCausalLM +from transformers.models.opt.modeling_opt import OPTForCausalLM +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.bloom.modeling_bloom import BloomBlock +from transformers import AwqConfig, AutoConfig +from bigdl.llm.transformers.awq.linear import WQLinear_GEMM, WQLinear_GEMV +from huggingface_hub import snapshot_download +from bigdl.llm.utils.common import invalidInputError + + +layer_type_dict = { + "mpt": "MPTBlock", + "llama": "LlamaDecoderLayer", + "opt": "OPTDecoderLayer", + "RefinedWeb": "FalconDecoderLayer", + "RefinedWebModel": "FalconDecoderLayer", + "falcon": "FalconDecoderLayer", + "bloom": "BloomBlock", + "gptj": "GPTJBlock", + "gpt_bigcode": "GPTBigCodeBlock", + "mistral": "MistralDecoderLayer", + "gpt_neox": "GPTNeoXDecoderLayer", + "aquila": "AquilaDecoderLayer", +} + + +def set_op_by_name(layer, name, new_module): + levels = name.split('.') + if len(levels) > 1: + mod_ = layer + for l_idx in range(len(levels)-1): + if levels[l_idx].isdigit(): + mod_ = mod_[int(levels[l_idx])] + else: + mod_ = getattr(mod_, levels[l_idx]) + setattr(mod_, levels[-1], new_module) + else: + setattr(layer, name, new_module) + + +def _load_config(model_path, model_filename, safetensors=False, + trust_remote_code=True, max_new_tokens=4096): + # [STEP 1] Download model if path is not a directory + if not os.path.isdir(model_path): + ignore_patterns = ["*msgpack*", "*h5*"] + if safetensors: + ignore_patterns.extend(["*.pt*", "*.bin*"]) + else: + ignore_patterns.append("*.safetensors*") + + model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns) + + if model_filename != '': + model_weights_path = model_path + f'/{model_filename}' + else: + model_weights_path = model_path + + # Load model config and set max generation length + max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) + config.max_new_tokens 
= max_new_tokens + + return model_weights_path, config + + +def get_named_linears(module): + return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} + + +def get_blocks(model): + if model.__class__.__name__ == 'LlamaForCausalLM': + layers = model.model.layers + elif isinstance(model, OPTForCausalLM): + layers = model.model.decoder.layers + elif isinstance(model, BloomForCausalLM): + layers = model.transformer.h + elif "mpt" in str(model.__class__).lower(): + layers = model.transformer.blocks + elif "falcon" in str(model.__class__).lower(): + layers = model.transformer.h + elif "bigcode" in str(model.__class__).lower(): + layers = model.transformer.h + elif "neox" in str(model.__class__).lower(): + layers = model.gpt_neox.layers + else: + invalidInputError(False, f"Model type {type(model)} isn't supported.") + return layers + + +def get_layer_type(config): + if config.model_type not in layer_type_dict.keys(): + invalidInputError(False, f"{config.model_type} isn't supported yet.") + return layer_type_dict[config.model_type] + + +def scale_activations(module): + from bigdl.llm.transformers.awq.act import ScaledActivation + param = next(module.parameters()) + dtype = param.dtype + device = param.device + if isinstance(module, BloomBlock): + if isinstance(module.mlp.gelu_impl, ScaledActivation): + return + c = module.mlp.dense_h_to_4h.out_features + act = ScaledActivation( + module.mlp.gelu_impl, + torch.ones(c, dtype=dtype, device=device) + ) + set_op_by_name(module, "mlp.gelu_impl", act) + elif 'mptblock' in str(module.__class__.__name__).lower(): + if isinstance(module.ffn.act, ScaledActivation): + return + c = module.ffn.up_proj.out_features + act = ScaledActivation( + module.ffn.act, + torch.ones(c, dtype=dtype, device=device) + ) + set_op_by_name(module, "ffn.act", act) + elif 'falcon' in str(module.__class__).lower(): + if isinstance(module.mlp.act, ScaledActivation): + return + c = module.mlp.dense_h_to_4h.out_features + act = ScaledActivation( + module.mlp.act, + torch.ones(c, dtype=dtype, device=device) + ) + set_op_by_name(module, "mlp.act", act) + elif 'bigcode' in str(module.__class__).lower(): + if isinstance(module.mlp.act, ScaledActivation): + return + c = module.mlp.c_proj.out_features + act = ScaledActivation( + module.mlp.act, + torch.ones(c, dtype=dtype, device=device) + ) + set_op_by_name(module, "mlp.act", act) + elif 'neox' in str(module.__class__).lower(): + if isinstance(module.mlp.act, ScaledActivation): + return + c = module.mlp.dense_h_to_4h.out_features + act = ScaledActivation( + module.mlp.act, + torch.ones(c, dtype=dtype, device=device) + ) + set_op_by_name(module, "mlp.act", act) + + +def _replace_with_awq_layers(model, awq_config: AwqConfig): + layers = get_blocks(model) + + for i in tqdm(range(len(layers)), desc="Replacing layers..."): + layer = layers[i] + + # Get every linear layer in a block + named_linears = get_named_linears(layer) + + # Replace activation functions + scale_activations(layer) + + # Replace nn.Linear with WQLinear + for name, module in named_linears.items(): + if awq_config.version == 'gemm': + q_linear_module = WQLinear_GEMM + elif awq_config.version == 'gemv': + q_linear_module = WQLinear_GEMV + + q_linear = q_linear_module.from_linear(module, + awq_config.bits, + awq_config.group_size, + True) + q_linear.to(next(layer.parameters()).device) + set_op_by_name(layer, name, q_linear) + + gc.collect() diff --git a/python/llm/src/bigdl/llm/transformers/awq/awq_config.py 
b/python/llm/src/bigdl/llm/transformers/awq/awq_config.py new file mode 100644 index 00000000..0f60a833 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/awq/awq_config.py @@ -0,0 +1,99 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# =========================================================================== +# +# This file is copied from +# https://github.com/huggingface/transformers +# +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from dataclasses import dataclass +from bigdl.llm.utils.common import invalidInputError +from transformers.utils.quantization_config import QuantizationConfigMixin +from transformers.utils.quantization_config import AwqBackendPackingMethod,\ + AWQLinearVersion, QuantizationMethod + + +@dataclass +class AwqConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can + play with a model that has been loaded using `auto-awq` library awq quantization + relying on auto_awq backend. + + Args: + bits (`int`, *optional*, defaults to 4): + The number of bits to quantize to. + group_size (`int`, *optional*, defaults to 128): + The group size to use for quantization. + Recommended value is 128 and -1 uses per-column quantization. + zero_point (`bool`, *optional*, defaults to `True`): + Whether to use zero point quantization. + version (`AWQLinearVersion`, *optional*, defaults to + `AWQLinearVersion.GEMM`): + The version of the quantization algorithm to use. + GEMM is better for big batch_size (e.g. >= 8) otherwise, + GEMV is better (e.g. < 8 ) + backend (`AwqBackendPackingMethod`, *optional*, defaults to + `AwqBackendPackingMethod.AUTOAWQ`): + The quantization backend. Some models might be quantized using `llm-awq` backend. + This is useful for users that quantize their own models using `llm-awq` library. 
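+
+        Example (an illustrative sketch; the values simply mirror the defaults above,
+        which are also the only combination accepted by bigdl-llm's AWQ loading path:
+        4-bit, GEMM version, autoawq backend, zero_point=True):
+
+        ```python
+        from bigdl.llm.transformers.awq.awq_config import AwqConfig
+
+        # equivalent to AwqConfig() with all defaults
+        awq_config = AwqConfig(bits=4, group_size=128, zero_point=True)
+        ```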
+ """ + + def __init__( + self, + bits: int = 4, + group_size: int = 128, + zero_point: bool = True, + version: AWQLinearVersion = AWQLinearVersion.GEMM, + backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ, + **kwargs, + ): + self.quant_method = QuantizationMethod.AWQ + + self.bits = bits + self.group_size = group_size + self.zero_point = zero_point + self.version = version + self.backend = backend + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + invalidInputError(self.backend == AwqBackendPackingMethod.AUTOAWQ, + "Only supported quantization backends in " + f"{AwqBackendPackingMethod.AUTOAWQ} - " + f"not recognized backend {self.backend}") + + invalidInputError(self.version in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV], + "Only supported versions are in [AWQLinearVersion.GEMM," + f"AWQLinearVersion.GEMV] - not recognized version {self.version}") diff --git a/python/llm/src/bigdl/llm/transformers/awq/linear.py b/python/llm/src/bigdl/llm/transformers/awq/linear.py new file mode 100644 index 00000000..acce09a9 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/awq/linear.py @@ -0,0 +1,284 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# =========================================================================== +# +# This file is adapted from +# https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/linear.py +# +# MIT License +# +# Copyright (c) 2023 MIT HAN Lab +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# + +import torch +import torch.nn as nn +from bigdl.llm.utils.common import invalidOperationError, invalidInputError + + +def make_divisible(c, divisor): + return (c + divisor - 1) // divisor + + +def calculate_zeros_width(in_features, group_size=128, pack_num=8): + if group_size >= 128: + size_multiplier = 1 + elif group_size == 64: + size_multiplier = 2 + elif group_size == 32: + size_multiplier = 4 + else: + invalidOperationError(False, + f"Not implemented group size {group_size}.") + + base_width = make_divisible(in_features // group_size, pack_num) + base_width = make_divisible(base_width, size_multiplier) * size_multiplier + return base_width + + +class WQLinear_GEMM(nn.Module): + def __init__(self, bits, group_size, in_features, out_features, bias, dev): + super().__init__() + + invalidOperationError(bits == 4, "Only 4-bit are supported for now.") + + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.group_size = group_size if group_size != -1 else in_features + + self.wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], + dtype=torch.int32) * self.bits).unsqueeze(0) + + # quick sanity check (make sure aligment) + invalidInputError(self.in_features % self.group_size == 0, + f"Invalid in_features number {self.in_features}.") + invalidInputError(out_features % (32 // self.bits) == 0, + f"Invalid out_features number {out_features}.") + + self.register_buffer('qweight', + torch.zeros((in_features, + out_features // (32 // self.bits)), + dtype=torch.int32, device=dev)) + self.register_buffer('qzeros', + torch.zeros((in_features // self.group_size, + out_features // (32 // self.bits)), + dtype=torch.int32, device=dev)) + self.register_buffer('scales', + torch.zeros((in_features // self.group_size, out_features), + dtype=torch.float16, device=dev)) + if bias: + self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, + device=dev)) + else: + self.bias = None + + @classmethod + def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None): + awq_linear = cls(bits, group_size, linear.in_features, linear.out_features, + linear.bias is not None, linear.weight.device) + if init_only: # just prepare for loading sd + return awq_linear + + # need scales and zeros info for real quantization + invalidInputError(scales is not None and zeros is not None, + "Scales and zeros should not be None.") + scale_zeros = zeros * scales + + awq_linear.scales = scales.clone().half() + if linear.bias is not None: + awq_linear.bias = linear.bias.clone().half() + + pack_num = 32 // awq_linear.bits + + intweight = [] + for idx in range(awq_linear.in_features): + intweight.append( + torch.round((linear.weight.data[:, idx] + + scale_zeros[idx // group_size]) / + awq_linear.scales[idx // group_size]).to(torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.to(dtype=torch.int32) + qweight = torch.zeros((intweight.shape[0], + intweight.shape[1] // (32 // awq_linear.bits)), + dtype=torch.int32, device=intweight.device) + + torch.set_printoptions(threshold=10_000) + print(intweight) + + for col in range(intweight.shape[1] // pack_num): + if awq_linear.bits == 4: + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + else: + invalidOperationError(False, "Only 4-bit are supported for now.") + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * awq_linear.bits) + awq_linear.qweight = qweight + + zeros = 
zeros.to(dtype=torch.int32) + qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // (32 // awq_linear.bits)), + dtype=torch.int32, device=zeros.device) + + for col in range(zeros.shape[1] // pack_num): + if awq_linear.bits == 4: + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + else: + invalidOperationError(False, "Only 4-bit are supported for now.") + for i in range(pack_num): + qzero_col = zeros[:, col * pack_num + order_map[i]] + qzeros[:, col] |= qzero_col << (i * awq_linear.bits) + awq_linear.qzeros = qzeros + + return awq_linear + + @torch.no_grad() + def forward(self, x): + invalidOperationError(False, "Bigdl-llm does not support inference awq models directly.") + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}, bits={}, group_size={}'.format( + self.in_features, self.out_features, self.bias is not None, self.bits, self.group_size + ) + + +class WQLinear_GEMV(nn.Module): + def __init__(self, bits, group_size, in_features, out_features, bias, dev): + super().__init__() + + invalidOperationError(bits == 4, "Only 4-bit are supported for now.") + + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.group_size = group_size if group_size != -1 else in_features + self.split_k_iters = 8 + + # quick sanity check (make sure aligment) + invalidInputError(self.in_features % self.group_size == 0, + f"Invalid in_features number {self.in_features}.") + invalidInputError(out_features % (32 // self.bits) == 0, + f"Invalid out_features number {out_features}.") + pack_num = (32 // self.bits) + + self.register_buffer('qweight', + torch.zeros((out_features, in_features // pack_num), + dtype=torch.int32, device=dev)) + self.register_buffer('qzeros', + torch.zeros((out_features, + calculate_zeros_width(in_features, + self.group_size)), + dtype=torch.int32, device=dev)) + self.register_buffer('scales', + torch.zeros((out_features, + calculate_zeros_width(in_features, self.group_size) + * pack_num), dtype=torch.float16, device=dev)) + if bias: + self.register_buffer('bias', torch.zeros((out_features), + dtype=torch.float16, device=dev)) + else: + self.bias = None + + @classmethod + def from_linear(cls, linear, bits, group_size, init_only=False, scales=None, zeros=None): + awq_linear = cls(bits, group_size, linear.in_features, linear.out_features, + linear.bias is not None, linear.weight.device) + if init_only: # just prepare for loading sd + return awq_linear + + # need scales and zeros info for real quantization + invalidInputError(scales is not None and zeros is not None, + "Scales and zeros should not be None.") + scale_zeros = zeros * scales + + pack_num = 32 // awq_linear.bits + qscales = torch.zeros( + (scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num), + dtype=torch.float16, + device=scales.device + ) + qscales[:, :scales.shape[1]] = scales + awq_linear.scales = qscales + if linear.bias is not None: + awq_linear.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(awq_linear.in_features): + intweight.append( + torch.round((linear.weight.data[:, idx] + + scale_zeros[:, idx // group_size]) / + awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.to(dtype=torch.int32) + qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.bits), + dtype=torch.int32, device=intweight.device) + + for col in range(intweight.shape[1] // pack_num): + if awq_linear.bits == 4: + order_map = [0, 1, 2, 3, 
4, 5, 6, 7] + else: + invalidOperationError(False, "Only 4-bit are supported for now.") + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * awq_linear.bits) + awq_linear.qweight = qweight + + zeros = zeros.to(dtype=torch.int32) + qzeros = torch.zeros( + (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)), + dtype=torch.int32, + device=zeros.device, + ) + + for col in range((zeros.shape[1] + pack_num - 1) // pack_num): + if awq_linear.bits == 4: + order_map = [0, 1, 2, 3, 4, 5, 6, 7] + else: + invalidOperationError(False, "Only 4-bit are supported for now.") + for i in range(pack_num): + if col * pack_num + order_map[i] >= zeros.shape[1]: + continue + qzero_col = zeros[:, col * pack_num + order_map[i]] + qzeros[:, col] |= qzero_col << (i * awq_linear.bits) + awq_linear.qzeros = qzeros + return awq_linear + + @torch.no_grad() + def forward(self, x): + invalidOperationError(False, "Bigdl-llm does not support inference awq models directly.") + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}, bits={}, group_size={}'.format( + self.in_features, self.out_features, self.bias is not None, self.bits, self.group_size + ) diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index a1a9f7f1..64e18511 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -53,6 +53,10 @@ def is_auto_gptq_available(): return importlib.util.find_spec("auto_gptq") is not None +def is_auto_awq_available(): + return importlib.util.find_spec("awq") is not None + + def is_deepspeed_available(): return importlib.util.find_spec("deepspeed") is not None @@ -61,18 +65,24 @@ if is_auto_gptq_available(): from auto_gptq.utils.peft_utils import QuantLinearCuda, QuantLinearCudaOld +if is_auto_awq_available(): + from bigdl.llm.transformers.awq.linear import WQLinear_GEMM + + def is_linear_module(module): in_features = None out_features = None mp_group = None + is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM) + if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld): in_features = module.infeatures out_features = module.outfeatures mp_group = None result = True - elif isinstance(module, nn.Linear): + elif isinstance(module, nn.Linear) or is_awq: in_features = module.in_features out_features = module.out_features mp_group = None @@ -102,8 +112,7 @@ from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size Q4_1 = get_ggml_qk_size("asym_int4") -def convert_gptq(module): - +def convert_gptq(module, awq=False): scales = module.scales zeros = torch.bitwise_right_shift( @@ -111,14 +120,22 @@ def convert_gptq(module): module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8) zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1) - zeros = zeros + 1 + if not awq: + zeros = zeros + 1 zeros = zeros.reshape(scales.shape) - weight = torch.bitwise_right_shift( - torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1), - module.wf.unsqueeze(-1)).to(torch.int8) - weight = torch.bitwise_and(weight, (2 ** module.bits) - 1) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + if awq: + weight = torch.bitwise_right_shift( + torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // module.bits), + module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8) + weight = 
torch.bitwise_and(weight, (2 ** module.bits) - 1) + weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2]) + else: + weight = torch.bitwise_right_shift( + torch.unsqueeze(module.qweight, 1).expand(-1, 32 // module.bits, -1), + module.wf.unsqueeze(-1)).to(torch.int8) + weight = torch.bitwise_and(weight, (2 ** module.bits) - 1) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) # convert weight to ggml format weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1]) @@ -171,7 +188,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, in_features, out_features, mp_group = linear_args with init_empty_weights(): new_linear = None - if is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld): + is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld) + is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM) + if is_gptq or is_awq: has_bias = module.bias is not None and module.bias.abs().sum() != 0 new_linear = LowBitLinear( in_features, @@ -184,7 +203,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, invalidInputError(device_type != "meta", "converting from meta device is not supported") # Copy the weights - paramsLowBit = FP4Params(data=convert_gptq(module), + paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq), requires_grad=False, quantized=True, _shape=(out_features, in_features), diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 74aae079..d0e40444 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -13,6 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# +# MIT License +# +# Copyright (c) 2023 MIT HAN Lab +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# import transformers from transformers.configuration_utils import PretrainedConfig @@ -123,6 +146,29 @@ class _BaseAutoModelClass: from transformers import GPTQConfig user_quantization_config = GPTQConfig(bits=4, use_exllama=False) kwargs["quantization_config"] = user_quantization_config + elif q_config["quant_method"] == "awq": + from bigdl.llm.transformers.awq.awq_config import AwqConfig + awq_config = AwqConfig.from_dict(q_config) + invalidInputError(awq_config.bits == 4, + "Only 4-bit awq is supported in bigdl-llm.") + invalidInputError(awq_config.version == "gemm", + "Only gemm version is supported in bigdl-llm.") + invalidInputError(awq_config.backend == "autoawq", + "Only autoawq backend is supported in bigdl-llm.") + invalidInputError(awq_config.zero_point, + "Only awq zero_point = True is supported in bigdl-llm.") + if load_in_low_bit is not None: + invalidInputError(load_in_low_bit == "asym_int4", + "You can only load awq model as aysm_int4 low bit type.") + + load_in_low_bit = "asym_int4" + + if int(awq_config.group_size) % get_ggml_qk_size(load_in_low_bit) != 0: + invalidInputError(False, + (f"group_size must be divisible by " + f"{get_ggml_qk_size(load_in_low_bit)}.")) + + kwargs["quantization_config"] = awq_config # load int x-bit kwargs["low_cpu_mem_usage"] = True @@ -156,16 +202,63 @@ class _BaseAutoModelClass: # and lead to args missing. modules_to_not_convert = kwargs.pop("modules_to_not_convert", None) replace_embedding = kwargs.pop("replace_embedding", False) + quant_config = kwargs.pop("quantization_config", None) _args = copy.deepcopy(args) _kwargs = copy.deepcopy(kwargs) - try: - model = cls.HF_Model.from_pretrained(*args, **kwargs) - except NotImplementedError: - logger.info("Failed to load models with `low_cpu_mem_usage` specified, " - "will fall to traditional load method with higher memory consumption.") - _kwargs["low_cpu_mem_usage"] = False - model = cls.HF_Model.from_pretrained(*_args, **_kwargs) - model.config.update({"bigdl_lcmu_enabled": False}) + awq_config = None + + if quant_config and quant_config.quant_method == "awq": + # The latest transformers only support cuda version + # This load awq ckpt logic is copied from + # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/base.py#L147 + from accelerate import init_empty_weights, infer_auto_device_map,\ + load_checkpoint_in_model + from bigdl.llm.transformers.awq.awq import _replace_with_awq_layers,\ + get_layer_type, _load_config + awq_config = quant_config + model_weights_path, config = _load_config(args[0], '', max_new_tokens=None, + safetensors=True) + with init_empty_weights(): + model = cls.HF_Model.from_config(config=config, trust_remote_code=True) + + _replace_with_awq_layers(model, awq_config=awq_config) + + model.tie_weights() + + # Get device map + device_map = infer_auto_device_map( + model, + no_split_module_classes=[get_layer_type(config)], + max_memory=None, + dtype=config.torch_dtype + ) + + # Load checkpoint + load_checkpoint_in_model( + model, + checkpoint=model_weights_path, + device_map=device_map, + offload_folder=None, + dtype=config.torch_dtype + ) + + # Offloading dispatch + from accelerate import dispatch_model + model = dispatch_model( + model, + device_map=device_map, + offload_dir=None + ) + else: + try: + model = cls.HF_Model.from_pretrained(*args, **kwargs) + except NotImplementedError: + logger.info("Failed to load models with `low_cpu_mem_usage` specified, " + "will fall to traditional load method with higher memory consumption.") + _kwargs["low_cpu_mem_usage"] 
= False + model = cls.HF_Model.from_pretrained(*_args, **_kwargs) + model.config.update({"bigdl_lcmu_enabled": False}) + model = model.to("cpu") model = ggml_convert_low_bit(model, qtype, optimize_model, modules_to_not_convert=modules_to_not_convert,
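With the changes above, an AWQ checkpoint is detected from the `quant_method` in its quantization config and loaded through the regular BigDL-LLM API. Below is a minimal sketch condensed from the CPU `generate.py` example added in this patch (the model id and prompt are the ones used there); it is illustrative rather than a definitive usage pattern:

```python
import torch
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

# The AWQ quantization_config in the checkpoint is picked up automatically and
# the relevant linear layers are converted to BigDL-LLM's asym_int4 format at load time.
model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-AWQ",
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = LlamaTokenizer.from_pretrained("TheBloke/Llama-2-7B-Chat-AWQ",
                                           trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```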