From 404e101dedb3818f169dc129483a86cde221e930 Mon Sep 17 00:00:00 2001 From: Yina Chen <33650826+cyita@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:36:21 +0800 Subject: [PATCH] QALora example (#9551) * Support qa-lora * init * update * update * update * update * update * update merge * update * fix style & update scripts * update * address comments * fix typo * fix typo --------- Co-authored-by: Yang Wang --- .../QLoRA-FineTuning/alpaca-qlora/README.md | 48 +++++++--- .../alpaca-qlora/alpaca_qlora_finetuning.py | 17 ++-- .../qalora_finetune_llama2_7b_arc_1_card.sh | 29 +++++++ .../qalora_finetune_llama2_7b_arc_2_card.sh | 34 ++++++++ ...lora_finetune_llama2_7b_pvc_1550_1_card.sh | 34 ++++++++ ...lora_finetune_llama2_7b_pvc_1550_1_tile.sh | 31 +++++++ .../QLoRA-FineTuning/export_merged_model.py | 87 +++++++++++++------ .../GPU/QLoRA-FineTuning/qlora_finetuning.py | 4 +- .../llm/src/bigdl/llm/transformers/convert.py | 4 +- .../bigdl/llm/transformers/low_bit_linear.py | 6 +- .../llm/src/bigdl/llm/transformers/model.py | 10 +-- .../llm/src/bigdl/llm/transformers/qlora.py | 31 ++++++- 12 files changed, 281 insertions(+), 54 deletions(-) create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_1_card.sh create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_2_card.sh create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md index b8695f74..bc8ca67e 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md @@ -1,6 +1,6 @@ -# Alpaca QLoRA Finetuning (experimental support) +# Alpaca QLoRA & QA-LoRA Finetuning (experimental support) -This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM QLoRA on [Intel GPUs](../../README.md). +This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using either [QLoRA](https://arxiv.org/abs/2305.14314) or [QA-LoRA](https://arxiv.org/abs/2309.14717) algorithm) on [Intel GPU](../../README.md). ### 0. Requirements To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../../README.md#requirements) for more information. @@ -28,54 +28,75 @@ source /opt/intel/oneapi/setvars.sh Here, we provide example usages on different hardware. 
Please refer to the appropriate script based on your device: -#### Finetuning LLaMA2-7B on single Arc A770 +#### QLoRA + +##### Finetuning LLaMA2-7B on single Arc A770 ```bash bash finetune_llama2_7b_arc_1_card.sh ``` -#### Finetuning LLaMA2-7B on two Arc A770 +##### Finetuning LLaMA2-7B on two Arc A770 ```bash bash finetune_llama2_7b_arc_2_card.sh ``` -#### Finetuning LLaMA2-7B on single Data Center GPU Flex 170 +##### Finetuning LLaMA2-7B on single Data Center GPU Flex 170 ```bash bash finetune_llama2_7b_flex_170_1_card.sh ``` -#### Finetuning LLaMA2-7B on three Data Center GPU Flex 170 +##### Finetuning LLaMA2-7B on three Data Center GPU Flex 170 ```bash bash finetune_llama2_7b_flex_170_3_card.sh ``` -#### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1100 +##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1100 ```bash bash finetune_llama2_7b_pvc_1100_1_card.sh ``` -#### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100 +##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100 ```bash bash finetune_llama2_7b_pvc_1100_4_card.sh ``` -#### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550 +##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550 ```bash bash finetune_llama2_7b_pvc_1550_1_card.sh ``` -#### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550 +##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550 ```bash bash finetune_llama2_7b_pvc_1550_4_card.sh ``` +#### QA-LoRA +##### Finetuning LLaMA2-7B on single Arc A770 + +```bash +bash qalora_finetune_llama2_7b_arc_1_card.sh +``` + +##### Finetuning LLaMA2-7B on two Arc A770 + +```bash +bash qalora_finetune_llama2_7b_arc_2_card.sh +``` + +##### Finetuning LLaMA2-7B on single Tile Intel Data Center GPU Max 1550 + +```bash +bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh +``` + **Important: If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by specifying `resume_from_checkpoint` to the local checkpoint folder as following:** ```bash python ./alpaca_qlora_finetuning.py \ @@ -97,3 +118,10 @@ python ./alpaca_qlora_finetuning.py \ {'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02} 1%|█ | 8/1164 [xx:xx training.log diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh new file mode 100644 index 00000000..bba9a757 --- /dev/null +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh @@ -0,0 +1,34 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +export MASTER_ADDR=127.0.0.1 +export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores +export FI_PROVIDER=tcp +export CCL_ATL_TRANSPORT=ofi + +mpirun -n 2 \ + python -u ./alpaca_qlora_finetuning.py \ + --base_model "meta-llama/Llama-2-7b-hf" \ + --data_path "yahma/alpaca-cleaned" \ + --output_dir "./bigdl-qlora-alpaca" \ + --qa_lora True \ + --learning_rate 9e-5 \ + --micro_batch_size 8 \ + --batch_size 128 \ + --lora_r 8 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --val_set_size 2000 > training.log \ No newline at end of file diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh new file mode 100644 index 00000000..eae51ea6 --- /dev/null +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh @@ -0,0 +1,31 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file + +python ./alpaca_qlora_finetuning.py \ + --base_model "meta-llama/Llama-2-7b-hf" \ + --data_path "yahma/alpaca-cleaned" \ + --output_dir "./bigdl-qlora-alpaca" \ + --learning_rate 9e-5 \ + --micro_batch_size 8 \ + --batch_size 128 \ + --gradient_checkpointing False \ + --lora_r 8 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --val_set_size 2000 \ + --qa_lora True \ No newline at end of file diff --git a/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py b/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py index 97079671..06792386 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py +++ b/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py @@ -31,11 +31,13 @@ import os import torch -import transformers from transformers import LlamaTokenizer # noqa: F402 -from bigdl.llm.transformers.qlora import PeftModel +from bigdl.llm.transformers.qlora import PeftModel, LoraConfig from bigdl.llm.transformers import AutoModelForCausalLM +from bigdl.llm.transformers.low_bit_linear import get_block_size import argparse +import tempfile +import shutil if __name__ == "__main__": @@ -51,31 +53,66 @@ if __name__ == "__main__": adapter_path = args.adapter_path tokenizer = LlamaTokenizer.from_pretrained(base_model) - base_model = AutoModelForCausalLM.from_pretrained( - base_model, - # load_in_low_bit="nf4", # should load the orignal model - torch_dtype=torch.float16, - device_map={"": "cpu"}, - ) + lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json")) + qa_lora = lora_config.get("qa_lora", False) - lora_model = PeftModel.from_pretrained( - base_model, - adapter_path, - device_map={"": "cpu"}, - torch_dtype=torch.float16, - ) + temp_dir = None + if qa_lora: + # Convert the qa-lora adapter to the correct shapes + # The default 4-bit 
format for qa_lora is sym_int4 + block_size = get_block_size("sym_int4") + temp_dir = tempfile.TemporaryDirectory() + tmpdirname = os.path.join(temp_dir.name, "adapter") + try: + shutil.copytree(adapter_path, tmpdirname) + except Exception as e: + print(f"Failed to copy adapter dir, error: {e}") + mid_lora_path = os.path.join(tmpdirname, "adapter_model.bin") - # merge weights - new merging method from peft - lora_model = lora_model.merge_and_unload() + adapter_path = os.path.join(adapter_path, "adapter_model.bin") - lora_model.train(False) + lora = torch.load(adapter_path, map_location='cpu') + # Get lora_a names + tmp_keys = [key for key in lora.keys() if 'lora_A' in key] - lora_model_sd = lora_model.state_dict() - deloreanized_sd = { - k.replace("base_model.model.", ""): v - for k, v in lora_model_sd.items() - if "lora" not in k - } + for tmp_key in tmp_keys: + lora_a = lora[tmp_key] / block_size + lora[tmp_key] = torch.repeat_interleave(lora_a, block_size, dim=1) - base_model.save_pretrained(args.output_path, state_dict=deloreanized_sd) - tokenizer.save_pretrained(args.output_path) + torch.save(lora, mid_lora_path) + adapter_path = tmpdirname + + try: + base_model = AutoModelForCausalLM.from_pretrained( + base_model, + # load_in_low_bit="nf4", # should load the orignal model + torch_dtype=torch.float16, + device_map={"": "cpu"}, + ) + + lora_model = PeftModel.from_pretrained( + base_model, + adapter_path, + device_map={"": "cpu"}, + torch_dtype=torch.float16, + ) + + # merge weights - new merging method from peft + lora_model = lora_model.merge_and_unload() + + lora_model.train(False) + + lora_model_sd = lora_model.state_dict() + deloreanized_sd = { + k.replace("base_model.model.", ""): v + for k, v in lora_model_sd.items() + if "lora" not in k + } + + base_model.save_pretrained(args.output_path, state_dict=deloreanized_sd) + tokenizer.save_pretrained(args.output_path) + except Exception as e: + print(f"Failed to merge the adapter, error: {e}.") + finally: + if qa_lora and temp_dir: + temp_dir.cleanup() diff --git a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py index 06cca4ea..c2044f26 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py +++ b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py @@ -19,9 +19,9 @@ import os import transformers from transformers import LlamaTokenizer -from peft import LoraConfig import intel_extension_for_pytorch as ipex -from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training +from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\ + LoraConfig from bigdl.llm.transformers import AutoModelForCausalLM from datasets import load_dataset import argparse diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 3ba482cd..2df8811a 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -109,8 +109,8 @@ def is_linear_module(module): def convert_gptq(module, awq=False): - from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size - Q4_1 = get_ggml_qk_size("asym_int4") + from bigdl.llm.transformers.low_bit_linear import get_block_size + Q4_1 = get_block_size("asym_int4") scales = module.scales diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index a95623ff..6fb465c6 100644 --- 
a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -71,10 +71,14 @@ MOFQ4 = ggml_tensor_qtype["mixed_fp4"] MOFQ8 = ggml_tensor_qtype["mixed_fp8"] -def get_ggml_qk_size(qtype: str): +def get_block_size(qtype: str): return ggml.ggml_qk_size(ggml_tensor_qtype[qtype]) +def get_qk_size(qtype: int): + return ggml.ggml_qk_size(qtype) + + def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None, convert_shape_only=False): QK = ggml.ggml_qk_size(qtype) diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 19fe31da..beb6a2c3 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -124,7 +124,7 @@ class _BaseAutoModelClass: if load_in_4bit or load_in_low_bit: if config_dict.get("quantization_config", None) is not None: - from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size + from bigdl.llm.transformers.low_bit_linear import get_block_size q_config = config_dict["quantization_config"] if q_config["quant_method"] == "gptq": invalidInputError(q_config["bits"] == 4, @@ -136,10 +136,10 @@ class _BaseAutoModelClass: "You can only load gptq model as aysm_int4 low bit type.") load_in_low_bit = "asym_int4" - if int(q_config["group_size"]) % get_ggml_qk_size(load_in_low_bit) != 0: + if int(q_config["group_size"]) % get_block_size(load_in_low_bit) != 0: invalidInputError(False, (f"group_size must be divisible by " - f"{get_ggml_qk_size(load_in_low_bit)}.")) + f"{get_block_size(load_in_low_bit)}.")) if user_quantization_config is not None: invalidInputError(user_quantization_config.bits == 4, "Only 4-bit gptq is supported in bigdl-llm.") @@ -166,10 +166,10 @@ class _BaseAutoModelClass: load_in_low_bit = "asym_int4" - if int(awq_config.group_size) % get_ggml_qk_size(load_in_low_bit) != 0: + if int(awq_config.group_size) % get_block_size(load_in_low_bit) != 0: invalidInputError(False, (f"group_size must be divisible by " - f"{get_ggml_qk_size(load_in_low_bit)}.")) + f"{get_block_size(load_in_low_bit)}.")) kwargs["quantization_config"] = awq_config diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py index d9fee5a3..53916ed9 100644 --- a/python/llm/src/bigdl/llm/transformers/qlora.py +++ b/python/llm/src/bigdl/llm/transformers/qlora.py @@ -49,7 +49,7 @@ # limitations under the License. 
import torch -from bigdl.llm.transformers.low_bit_linear import LowBitLinear +from bigdl.llm.transformers.low_bit_linear import LowBitLinear, get_qk_size from peft.tuners.lora import LoraLayer from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.utils import get_autocast_dtype @@ -66,6 +66,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer): r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + qa_lora: bool = True, **kwargs, ): LowBitLinear.__init__( @@ -76,7 +77,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer): bias=kwargs.get("bias", True), conver_to_half=False, ) - LoraLayer.__init__(self, in_features=in_features, out_features=out_features) + + qk_size = get_qk_size(kwargs.get("qtype")) + lora_in_features = in_features // qk_size if qa_lora else in_features + LoraLayer.__init__(self, in_features=lora_in_features, out_features=out_features) # Freezing the pre-trained weight matrix self.weight.requires_grad = False @@ -84,6 +88,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer): init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) self.active_adapter = adapter_name + if qa_lora: + self.qa_pool = torch.nn.AvgPool1d(qk_size) + else: + self.qa_pool = torch.nn.Identity() def forward(self, x: torch.Tensor): autocast_dtype = get_autocast_dtype(x) @@ -103,14 +111,16 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer): x = x.to(self.lora_A[self.active_adapter].weight.dtype) output = ( self.lora_B[self.active_adapter]( - self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) + self.lora_A[self.active_adapter]( + self.lora_dropout[self.active_adapter](self.qa_pool(x))) ).to(expected_dtype) * self.scaling[self.active_adapter] ) else: output = ( self.lora_B[self.active_adapter]( - self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) + self.lora_A[self.active_adapter]( + self.lora_dropout[self.active_adapter](self.qa_pool(x))) ) * self.scaling[self.active_adapter] ) @@ -126,6 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target low_bit_kwargs.update( { "qtype": target.qtype, + "qa_lora": lora_config.qa_lora, } ) new_module = LoraLowBitLinear(adapter_name, @@ -140,6 +151,14 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target from peft.tuners.lora import LoraModel +from peft.tuners.lora import LoraConfig as LoraConfigBase +from dataclasses import dataclass, field + + +@dataclass +class LoraConfig(LoraConfigBase): + + qa_lora: bool = field(default=False, metadata={"help": "enable qa-lora"}) def get_peft_model(*args, **kwargs): @@ -357,6 +376,10 @@ def _setup_devices(self) -> "torch.device": torch.cuda.set_device(device) return device +from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING + +PEFT_TYPE_TO_CONFIG_MAPPING["lora"] = LoraConfig + # workaround a IPEX bug that prevents resume training in bf16 from accelerate import Accelerator Accelerator._prepare_ipex = patch_prepare_ipex
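
---

The key change in `qlora.py` above is that, with `qa_lora=True`, the LoRA branch no longer sees the full hidden vector: the input is group-wise average-pooled by the quantization block size (`get_qk_size(qtype)`), so `lora_A` only has `in_features // qk_size` inputs and the learned correction lines up with the per-group quantization parameters. Below is a minimal stand-alone sketch of that forward path; it is an illustration only, using a plain frozen `nn.Linear` in place of BigDL's `LowBitLinear`, omitting dropout, and picking `group_size=64` arbitrarily.

```python
import torch
import torch.nn as nn


class QALoraLinearSketch(nn.Module):
    """Illustrative QA-LoRA layer: frozen base weight + group-pooled LoRA branch."""

    def __init__(self, in_features, out_features, r=8, group_size=64, lora_alpha=16):
        super().__init__()
        assert in_features % group_size == 0
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad = False    # base weight stays frozen
        self.qa_pool = nn.AvgPool1d(group_size)   # pools the input per quantization group
        self.lora_A = nn.Linear(in_features // group_size, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)        # adapter starts as a no-op
        self.scaling = lora_alpha / r

    def forward(self, x):
        # x: (batch, seq, in_features); AvgPool1d pools over the last dimension
        pooled = self.qa_pool(x)                  # (batch, seq, in_features // group_size)
        return self.base(x) + self.lora_B(self.lora_A(pooled)) * self.scaling
```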
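`export_merged_model.py` handles the QA-LoRA case by rescaling each `lora_A` weight by `1 / block_size` and expanding it with `torch.repeat_interleave` along the input dimension before handing the adapter to PEFT. The expanded matrix computes exactly what the pooled adapter computed during training, which is why the ordinary `merge_and_unload()` path can then be reused. A quick numerical check of that equivalence (the shapes and block size here are illustrative, not tied to any particular model):

```python
import torch
import torch.nn.functional as F

block = 64                                    # quantization block size
in_features, r = 4096, 8
x = torch.randn(2, in_features)
lora_A = torch.randn(r, in_features // block) # QA-LoRA adapter shape

# What the QA-LoRA layer computes at training time: pool, then project.
pooled = F.avg_pool1d(x.unsqueeze(1), block).squeeze(1)
out_trained = pooled @ lora_A.T

# What export_merged_model.py stores: lora_A / block, repeated per group.
lora_A_full = torch.repeat_interleave(lora_A / block, block, dim=1)
out_exported = x @ lora_A_full.T

print(torch.allclose(out_trained, out_exported, atol=1e-4))  # True
```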
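For completeness, a hedged sketch of how the new `qa_lora` switch is consumed from Python, pieced together from the `qlora_finetuning.py` import change and the extended `LoraConfig`; the model id, low-bit format, `target_modules`, and hyperparameters are taken from the example scripts and patch comments rather than quoted verbatim, so treat them as assumptions:

```python
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (needed for the "xpu" device)
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
    LoraConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_low_bit="sym_int4",  # the patch notes sym_int4 as the default 4-bit format for qa_lora
    optimize_model=False,
    torch_dtype=torch.float16,
)
model = model.to("xpu")
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj"],
    qa_lora=True,   # the flag added by this patch; the shell scripts pass --qa_lora True
)
model = get_peft_model(model, config)
# From here the usual transformers.Trainer loop in alpaca_qlora_finetuning.py applies unchanged.
```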