From 2f36769208d8c3f7b3cc8a19a88e41e6496f50a8 Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Fri, 22 Dec 2023 11:05:39 +0800 Subject: [PATCH] LLM: bigdl-llm lora support & lora example (#9740) * lora support and single card example * support multi-card, refactor code * fix model id and style * remove torch patch, add two new class for bf16, update example * fix style * change to training_mode * small fix * add more info in help * fixstyle, update readme * fix ut * fix ut * Handling compatibility issues with default LoraConfig --- .../QLoRA-FineTuning/alpaca-qlora/README.md | 26 +- .../alpaca-qlora/alpaca_qlora_finetuning.py | 24 +- ...lora_finetune_llama2_7b_pvc_1110_4_card.sh | 31 +++ ...lora_finetune_llama2_7b_pvc_1550_1_tile.sh | 26 ++ ...lora_finetune_llama2_7b_pvc_1550_4_card.sh | 31 +++ .../qalora_finetune_llama2_7b_arc_1_card.sh | 2 +- .../qalora_finetune_llama2_7b_arc_2_card.sh | 2 +- ...lora_finetune_llama2_7b_pvc_1550_1_card.sh | 2 +- ...lora_finetune_llama2_7b_pvc_1550_1_tile.sh | 2 +- .../GPU/QLoRA-FineTuning/qlora_finetuning.py | 2 +- python/llm/src/bigdl/llm/ggml/quantize.py | 3 +- .../llm/src/bigdl/llm/transformers/convert.py | 21 +- .../bigdl/llm/transformers/low_bit_linear.py | 34 +++ .../llm/src/bigdl/llm/transformers/model.py | 4 +- .../llm/src/bigdl/llm/transformers/qlora.py | 259 +++++++----------- .../bigdl/llm/transformers/training_patch.py | 198 +++++++++++++ 16 files changed, 481 insertions(+), 186 deletions(-) create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1110_4_card.sh create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_1_tile.sh create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_4_card.sh create mode 100644 python/llm/src/bigdl/llm/transformers/training_patch.py diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md index 7386602c..aab2f449 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md @@ -1,6 +1,6 @@ -# Alpaca QLoRA & QA-LoRA Finetuning (experimental support) +# Alpaca Finetuning with BigDL-LLM -This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using either [QLoRA](https://arxiv.org/abs/2305.14314) or [QA-LoRA](https://arxiv.org/abs/2309.14717) algorithm) on [Intel GPU](../../README.md). +This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using either [QLoRA](https://arxiv.org/abs/2305.14314) / [QA-LoRA](https://arxiv.org/abs/2309.14717) or [LoRA](https://arxiv.org/abs/2106.09685) algorithm) on [Intel GPU](../../README.md). ### 0. Requirements To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../../README.md#requirements) for more information. @@ -26,6 +26,8 @@ source /opt/intel/oneapi/setvars.sh ### 3. Finetune +Three training modes are now supported ([QLoRA](https://arxiv.org/abs/2305.14314) / [QA-LoRA](https://arxiv.org/abs/2309.14717) / [LoRA](https://arxiv.org/abs/2106.09685)); to switch between them, set `training_mode` to `qlora` / `qalora` / `lora` in the scripts below. + Here, we provide example usages on different hardware.
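Under the hood, `training_mode` selects the weight format the base model is loaded in and is forwarded to BigDL-LLM's `LoraConfig`. The sketch below is a minimal illustration of that path, assuming a bigdl-llm build that contains this patch; the model id and LoRA hyper-parameters are placeholders, not a drop-in replacement for `alpaca_qlora_finetuning.py`:

```python
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the XPU backend)
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import get_peft_model, LoraConfig

training_mode = "lora"  # or "qlora" / "qalora"

# Same mapping as in alpaca_qlora_finetuning.py:
# qalora trains on sym_int4 weights, lora on bf16, qlora (the default) on nf4.
low_bit_format = {"qalora": "sym_int4", "lora": "bf16"}.get(training_mode, "nf4")

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",        # placeholder model id
    load_in_low_bit=low_bit_format,
)
model = model.to("xpu")

config = LoraConfig(
    r=8,                               # illustrative hyper-parameters
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    training_mode=training_mode,       # new field introduced by this PR
)
model = get_peft_model(model, config)  # ready to hand to a transformers Trainer
```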
Please refer to the appropriate script based on your device: #### QLoRA @@ -97,6 +99,26 @@ bash qalora_finetune_llama2_7b_arc_2_card.sh bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh ``` +#### LoRA + +##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100 + +```bash +bash lora_finetune_llama2_7b_pvc_1110_4_card.sh +``` + +##### Finetuning LLaMA2-7B on single Tile Intel Data Center GPU Max 1550 + +```bash +bash lora_finetune_llama2_7b_pvc_1550_1_tile.sh +``` + +##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550 + +```bash +bash lora_finetune_llama2_7b_pvc_1550_4_card.sh +``` + ### 4. (Optional) Resume Training If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by specifying `resume_from_checkpoint` to the local checkpoint folder as following:** ```bash diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py index a7bbf22a..c822bd57 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py @@ -48,10 +48,11 @@ from utils.prompter import Prompter import intel_extension_for_pytorch as ipex from bigdl.llm.transformers import AutoModelForCausalLM - # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\ - cast_lora_weight, LoraConfig + LoraConfig +from bigdl.llm.utils.common import invalidInputError + def get_int_from_env(env_keys, default): """Returns the first positive env value found in the `env_keys` list or the default.""" @@ -109,8 +110,10 @@ def train( prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca. gradient_checkpointing: bool = False, deepspeed: str = None, - qa_lora: bool = False, # if True, use qa-lora https://arxiv.org/abs/2309.14717 + training_mode: str = "qlora", ): + invalidInputError(training_mode in ["qlora", "qalora", "lora"], + "Only qlora / qalora / lora are supported for training_mode now.") if int(os.environ.get("LOCAL_RANK", 0)) == 0: print( f"Training Alpaca-LoRA model with params:\n" @@ -136,7 +139,7 @@ def train( f"wandb_log_model: {wandb_log_model}\n" f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" f"prompt template: {prompt_template_name}\n" - f"qa_lora: {qa_lora}\n" + f"training_mode: {training_mode}\n" ) assert ( base_model @@ -175,7 +178,12 @@ def train( else: # According to the QLoRA paper, using "nf4" could yield better model quality than "int4" # Default 4-bit format for qa-lora is sym_int4 - low_bit_format = "sym_int4" if qa_lora else "nf4" + if training_mode == "qalora": + low_bit_format = "sym_int4" + elif training_mode == "lora": + low_bit_format = "bf16" + else: + low_bit_format = "nf4" # Load the base model from a directory or the HF Hub to 4-bit format model = AutoModelForCausalLM.from_pretrained( base_model, @@ -196,7 +204,7 @@ def train( 0 # unk.
we want this to be different from the eos token ) tokenizer.padding_side = "left" # Allow batched inference - + print(model) def tokenize(prompt, add_eos_token=True): @@ -257,7 +265,7 @@ def train( lora_dropout=lora_dropout, bias="none", task_type="CAUSAL_LM", - qa_lora=qa_lora, + training_mode=training_mode, ) print(f"Lora Config: {config}") model = get_peft_model(model, config) @@ -301,7 +309,7 @@ def train( max_grad_norm=0.3, num_train_epochs=num_epochs, learning_rate=learning_rate, - lr_scheduler_type="constant" if qa_lora else "cosine", + lr_scheduler_type="constant" if training_mode=="qalora" else "cosine", bf16=True, # ensure training more stable logging_steps=1, optim="adamw_torch", diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1110_4_card.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1110_4_card.sh new file mode 100644 index 00000000..473a4ec1 --- /dev/null +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1110_4_card.sh @@ -0,0 +1,31 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +export MASTER_ADDR=127.0.0.1 +export OMP_NUM_THREADS=14 +export FI_PROVIDER=tcp +export CCL_ATL_TRANSPORT=ofi + +mpirun -n 4 \ + python -u ./alpaca_qlora_finetuning.py \ + --micro_batch_size 8 \ + --batch_size 128 \ + --base_model "meta-llama/Llama-2-7b-hf" \ + --data_path "yahma/alpaca-cleaned" \ + --output_dir "./bigdl-lora-alpaca" \ + --gradient_checkpointing True \ + --lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']" \ + --training_mode "lora" diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_1_tile.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_1_tile.sh new file mode 100644 index 00000000..0c61ba22 --- /dev/null +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_1_tile.sh @@ -0,0 +1,26 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file +python ./alpaca_qlora_finetuning.py \ + --micro_batch_size 8 \ + --batch_size 128 \ + --base_model "meta-llama/Llama-2-7b-hf" \ + --data_path "yahma/alpaca-cleaned" \ + --output_dir "./bigdl-lora-alpaca" \ + --gradient_checkpointing True \ + --lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']" \ + --training_mode "lora" diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_4_card.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_4_card.sh new file mode 100644 index 00000000..57dc3719 --- /dev/null +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_pvc_1550_4_card.sh @@ -0,0 +1,31 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +export MASTER_ADDR=127.0.0.1 +export OMP_NUM_THREADS=7 +export FI_PROVIDER=tcp +export CCL_ATL_TRANSPORT=ofi + +mpirun -n 8 \ + python -u ./alpaca_qlora_finetuning.py \ + --micro_batch_size 8 \ + --batch_size 128 \ + --base_model "meta-llama/Llama-2-7b-hf" \ + --data_path "yahma/alpaca-cleaned" \ + --output_dir "./bigdl-lora-alpaca" \ + --gradient_checkpointing False \ + --lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj']" \ + --training_mode "lora" diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_1_card.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_1_card.sh index 72cf3bbc..ae4a726d 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_1_card.sh +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_1_card.sh @@ -26,4 +26,4 @@ python ./alpaca_qlora_finetuning.py \ --lora_alpha 16 \ --lora_dropout 0.05 \ --val_set_size 2000 \ - --qa_lora True + --training_mode "qalora" diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_2_card.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_2_card.sh index ba4c9562..c1adcb11 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_2_card.sh +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_arc_2_card.sh @@ -31,4 +31,4 @@ mpirun -n 2 \ --lora_alpha 16 \ --lora_dropout 0.05 \ --val_set_size 2000 \ - --qa_lora True > training.log + --training_mode "qalora" > training.log diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh index bba9a757..f2e2cbaf 100644 --- 
a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_card.sh @@ -24,7 +24,7 @@ mpirun -n 2 \ --base_model "meta-llama/Llama-2-7b-hf" \ --data_path "yahma/alpaca-cleaned" \ --output_dir "./bigdl-qlora-alpaca" \ - --qa_lora True \ + --training_mode "qalora" \ --learning_rate 9e-5 \ --micro_batch_size 8 \ --batch_size 128 \ diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh index eae51ea6..e1da7d2c 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh +++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/qalora_finetune_llama2_7b_pvc_1550_1_tile.sh @@ -28,4 +28,4 @@ python ./alpaca_qlora_finetuning.py \ --lora_alpha 16 \ --lora_dropout 0.05 \ --val_set_size 2000 \ - --qa_lora True \ No newline at end of file + --training_mode "qalora" \ No newline at end of file diff --git a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py index 88dea463..21dbeaad 100644 --- a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py +++ b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py @@ -57,7 +57,7 @@ if __name__ == "__main__": target_modules=["q_proj", "k_proj", "v_proj"], lora_dropout=0.05, bias="none", - task_type="CAUSAL_LM" + task_type="CAUSAL_LM", ) model = get_peft_model(model, config) diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py index 607ee81a..185d31a7 100644 --- a/python/llm/src/bigdl/llm/ggml/quantize.py +++ b/python/llm/src/bigdl/llm/ggml/quantize.py @@ -38,7 +38,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml "mixed_fp4": 17, # Mixture of Formats Quantization 4 bits "mixed_fp8": 18, # Mixture of Formats Quantization 8 bits "fp8_e5m2": 19, # fp8 in e5m2 format - "fp8": 15} # fp8 in e4m3 format + "fp8": 15, # fp8 in e4m3 format + "bf16": 20} _llama_quantize_type = {"q4_0": 2, "q4_1": 3, diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 1c1f0b84..9fe3f3ba 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -172,7 +172,8 @@ def convert_gptq(module, awq=False): def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, current_key_name=None, convert_shape_only=False, cpu_embedding=False): - from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear + from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \ + FP16Linear, BF16Linear from bigdl.llm.transformers.embedding import LLMEmbedding has_been_replaced = False @@ -212,7 +213,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if has_bias: new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ .to(device_type) - elif qtype != ggml_tensor_qtype["fp16"]: + elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]: new_linear = LowBitLinear( in_features, out_features, @@ -233,7 +234,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if module.bias is not None: new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ .to(device_type) - else: + elif qtype == 
ggml_tensor_qtype["fp16"]: # only support two size now # may generalize to other sizes if module.in_features in [4096, 11008]: @@ -259,8 +260,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if module.bias is not None: new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ .to(device_type) + elif qtype == ggml_tensor_qtype["bf16"]: + new_linear = BF16Linear( + in_features, + out_features, + module.bias is not None, + mp_group=mp_group, + ) + device_type = module.weight.data.device.type + # convert here + new_linear._parameters['weight'] = nn.Parameter(module.weight) + if module.bias is not None: + new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\ + .to(device_type) - # fp16 may generalize to other sizes later if new_linear is not None: if not module.training: new_linear.eval() diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index b9558d50..a48e0f08 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -589,3 +589,37 @@ class FP16Linear(nn.Linear): result += self.bias return result.to(x.dtype) + + +class BF16Linear(nn.Linear): + def __init__(self, input_features, output_features, bias=True, + mp_group=None, compute_dtype=None): + super().__init__(input_features, output_features, bias) + self.in_len = input_features + self.out_len = output_features + self.weight_shape = (self.out_len, self.in_len) + self.weight_length = self.out_len * self.in_len + self.mp_group = mp_group + self.compute_dtype = compute_dtype + + def forward(self, x: torch.Tensor): + # only work for GPU now + invalidInputError(x.device.type == "xpu", + "bf16 only works for GPU now") + is_training = self.training and not torch.is_inference_mode_enabled() + if is_training: + # below logic is only for training + autocast_dtype = get_autocast_dtype(x) + if self.compute_dtype is not None and x.device.type == "xpu": + x = x.to(self.compute_dtype) # solve GC issue for unlora module + elif autocast_dtype is not None: + x = x.to(autocast_dtype) + + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + result = F.linear(x, self.weight) + if self.bias is not None: + result += self.bias + + return result.to(x.dtype) diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index fa4cd59a..8c6930dc 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -106,7 +106,7 @@ class _BaseAutoModelClass: if the model is GPTQ model. Default to be False. :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5 - , sym_int8, nf3, nf4, fp4, fp8, fp8_e4m3, fp8_e5m2 or fp16. + , sym_int8, nf3, nf4, fp4, fp8, fp8_e4m3, fp8_e5m2, fp16 or bf16. sym_int4 means symmetric int 4, asym_int4 means asymmetric int 4, nf4 means 4-bit NormalFloat, etc. Relevant low bit optimizations will be applied to the model. 
@@ -231,7 +231,7 @@ class _BaseAutoModelClass: invalidInputError(q_k in ggml_tensor_qtype, f"Unknown load_in_low_bit value: {q_k}, expected:" f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, " - "fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, mixed_fp4 or mixed_fp8.") + "fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, mixed_fp4 or mixed_fp8.") qtype = ggml_tensor_qtype[q_k] # In case it needs a second try, diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py index 0c55c840..c4ab77bc 100644 --- a/python/llm/src/bigdl/llm/transformers/qlora.py +++ b/python/llm/src/bigdl/llm/transformers/qlora.py @@ -32,7 +32,7 @@ # limitations under the License. # # Some parts of this file is adapted from -# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/training_args.py +# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/tuners/lora.py # # Copyright 2020 The HuggingFace Team. All rights reserved. # @@ -49,11 +49,12 @@ # limitations under the License. import torch -from bigdl.llm.transformers.low_bit_linear import LowBitLinear, get_qk_size +from bigdl.llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size from peft.tuners.lora import LoraLayer from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.utils import get_autocast_dtype import functools +from bigdl.llm.transformers import training_patch class LoraLowBitLinear(LowBitLinear, LoraLayer): @@ -128,22 +129,98 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer): return result +class LoraBF16Linear(BF16Linear, LoraLayer): + # Lora implemented in a dense layer + def __init__( + self, + adapter_name, + in_features, + out_features, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + BF16Linear.__init__( + self, + in_features, + out_features, + bias=kwargs.get("bias", True), + compute_dtype=torch.bfloat16, + ) + + LoraLayer.__init__(self, in_features=in_features, out_features=out_features) + + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + init_lora_weights = kwargs.pop("init_lora_weights", True) + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) + self.active_adapter = adapter_name + + def forward(self, x: torch.Tensor): + autocast_dtype = get_autocast_dtype(x) + if x.device.type == "xpu": + # force to use bf16 on gpu + x = x.to(torch.bfloat16) + elif autocast_dtype is not None: + x = x.to(autocast_dtype) + result = super().forward(x) + + if self.disable_adapters or self.active_adapter not in self.lora_A.keys(): + return result + elif self.r[self.active_adapter] > 0: + result = result.clone() + if autocast_dtype is None and x.device.type == "cpu": + expected_dtype = result.dtype + x = x.to(self.lora_A[self.active_adapter].weight.dtype) + output = ( + self.lora_B[self.active_adapter]( + self.lora_A[self.active_adapter]( + self.lora_dropout[self.active_adapter](x)) + ).to(expected_dtype) + * self.scaling[self.active_adapter] + ) + else: + output = ( + self.lora_B[self.active_adapter]( + self.lora_A[self.active_adapter]( + self.lora_dropout[self.active_adapter](x)) + ) + * self.scaling[self.active_adapter] + ) + result += output + return result + + def _create_new_module(create_new_module_func, lora_config, adapter_name, target, **kwargs): - if isinstance(target, LowBitLinear): + if isinstance(target, LowBitLinear) or isinstance(target, BF16Linear): low_bit_kwargs = kwargs.copy() bias = low_bit_kwargs.pop("bias", False) - 
low_bit_kwargs.update( - { - "qtype": target.qtype, - "qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False, - } - ) - new_module = LoraLowBitLinear(adapter_name, - target.in_features, - target.out_features, - bias=bias, - **low_bit_kwargs) + + if hasattr(lora_config, "training_mode") and lora_config.training_mode == "lora": + new_module = LoraBF16Linear(adapter_name, + target.in_features, + target.out_features, + bias=bias, + **low_bit_kwargs) + else: + if hasattr(lora_config, "training_mode"): + qa_lora = lora_config.training_mode == "qalora" + else: + qa_lora = False + low_bit_kwargs.update( + { + "qtype": target.qtype, + "qa_lora": qa_lora + } + ) + new_module = LoraLowBitLinear(adapter_name, + target.in_features, + target.out_features, + bias=bias, + **low_bit_kwargs) else: new_module = create_new_module_func(lora_config, adapter_name, target, **kwargs) @@ -157,8 +234,7 @@ from dataclasses import dataclass, field @dataclass class LoraConfig(LoraConfigBase): - - qa_lora: bool = field(default=False, metadata={"help": "enable qa-lora"}) + training_mode: str = field(default="qlora", metadata={"help": "determine training mode"}) def get_peft_model(*args, **kwargs): @@ -237,158 +313,10 @@ class PeftModel: return model -def patch_prepare_ipex(self, *args): - return tuple(args) - - -from transformers.utils import ( - requires_backends, - is_sagemaker_mp_enabled, - is_accelerate_available, - is_torch_xpu_available, - is_sagemaker_dp_enabled, - is_torch_tpu_available, - is_torch_npu_available) -from transformers.utils.generic import strtobool -from transformers.utils import cached_property -from transformers.training_args import logger, ParallelMode, DistributedType -import torch -import torch.distributed as dist -import os -import warnings -from datetime import timedelta - -if is_accelerate_available(): - from accelerate.state import AcceleratorState, PartialState - from accelerate.utils import DistributedType - -if is_sagemaker_mp_enabled(): - import smdistributed.modelparallel.torch as smp - - smp.init() - - -@cached_property -def _setup_devices(self) -> "torch.device": - requires_backends(self, ["torch"]) - logger.info("PyTorch: setting up devices") - if not is_sagemaker_mp_enabled(): - if not is_accelerate_available(min_version="0.20.1"): - invalidInputError( - False, - "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: " - "Please run `pip install transformers[torch]` or `pip install accelerate -U`" - ) - AcceleratorState._reset_state(reset_partial_state=True) - self.distributed_state = None - if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ: - os.environ["ACCELERATE_USE_IPEX"] = "false" - if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): - self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend) - self._n_gpu = 0 - elif is_sagemaker_mp_enabled(): - local_rank = smp.local_rank() - device = torch.device("cuda", local_rank) - self._n_gpu = 1 - torch.cuda.set_device(device) - elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ: - os.environ["ACCELERATE_USE_XPU"] = "true" - self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) - # device = torch.device("xpu:0") - device = self.distributed_state.device - self._n_gpu = 1 - elif is_sagemaker_dp_enabled(): - self.distributed_state = PartialState(_use_sagemaker_dp=True) - self._n_gpu = 1 - elif self.deepspeed: - # Need to do similar for Accelerator init - os.environ["ACCELERATE_USE_DEEPSPEED"] = 
"true" - self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) - del os.environ["ACCELERATE_USE_DEEPSPEED"] - self._n_gpu = 1 - else: - self.distributed_state = PartialState( - backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout) - ) - self._n_gpu = 1 - if not is_sagemaker_mp_enabled(): - device = self.distributed_state.device - self.local_rank = self.distributed_state.local_process_index - if dist.is_available() and dist.is_initialized() and \ - self.parallel_mode != ParallelMode.DISTRIBUTED: - logger.warning( - "torch.distributed process group is initialized, " - "but parallel_mode != ParallelMode.DISTRIBUTED. " - "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" - ) - if is_torch_tpu_available(): - device = self.distributed_state.device - self._n_gpu = 0 - elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): - # Already set _n_gpu - pass - elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU: - if "ACCELERATE_USE_XPU" not in os.environ: - os.environ["ACCELERATE_USE_XPU"] = "true" - # self._n_gpu = torch.xpu.device_count() - # device = torch.device("xpu:0") - # torch.xpu.set_device(device) - elif self.distributed_state.distributed_type == DistributedType.NO: - if self.use_mps_device: - warnings.warn( - "`use_mps_device` is deprecated and will be removed in" - " version 5.0 of 🤗 Transformers." - "`mps` device will be used by default if available similar" - " to the way `cuda` device is used." - "Therefore, no action from user is required. " - ) - if device.type != "mps": - invalidInputError(False, - ("Either you do not have an MPS-enabled device" - " on this machine or MacOS" - " version is not 12.3+ " - "or current PyTorch install was not built with MPS enabled.")) - if device.type == "mps": - self._n_gpu = 1 - elif self.use_cpu: - device = torch.device("cpu") - self._n_gpu = 0 - elif is_torch_xpu_available(): - device = torch.device("xpu:0") - torch.xpu.set_device(device) - self._n_gpu = 1 - elif is_torch_npu_available(): - device = torch.device("npu:0") - torch.npu.set_device(device) - self._n_gpu = 1 - else: - # if n_gpu is > 1 we'll use nn.DataParallel. - # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` - # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will - # trigger an error that a device index is missing. Index 0 takes into account the - # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` - # will use the first GPU in that env, i.e. GPU#1 - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # Sometimes the line in the postinit has not been run before we end up here, - # so just checking we're not at - # the default value. 
- self._n_gpu = torch.cuda.device_count() - if device.type == "cuda": - torch.cuda.set_device(device) - return device - from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING PEFT_TYPE_TO_CONFIG_MAPPING["lora"] = LoraConfig -# workaround a IPEX bug that prevents resume training in bf16 -from accelerate import Accelerator -Accelerator._prepare_ipex = patch_prepare_ipex - -# patch transformer for xpu DDP traing -from transformers import TrainingArguments -TrainingArguments._setup_devices = _setup_devices - def cast_lora_weight(model, dtype=torch.bfloat16): for name, module in model.named_modules(): @@ -396,6 +324,9 @@ def cast_lora_weight(model, dtype=torch.bfloat16): module.compute_dtype = dtype if isinstance(module, LoraLayer): module = module.to(dtype) + if isinstance(module, BF16Linear): + module = module.to(dtype) + module.compute_dtype = dtype if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: diff --git a/python/llm/src/bigdl/llm/transformers/training_patch.py b/python/llm/src/bigdl/llm/transformers/training_patch.py new file mode 100644 index 00000000..0c1addd6 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/training_patch.py @@ -0,0 +1,198 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/tuners/lora.py +# +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/training_args.py +# +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
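+# This module is imported purely for its side effects (qlora.py now does
+# `from bigdl.llm.transformers import training_patch`): it replaces
+# `Accelerator._prepare_ipex` with a no-op so that accelerate skips
+# `ipex.optimize` during training, working around an IPEX issue that prevented
+# resuming bf16 training, and it overrides `TrainingArguments._setup_devices`
+# so that device setup also covers single-card and DDP training on XPU.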
+ + +def patch_prepare_ipex(self, *args): + return tuple(args) + + +from transformers.utils import ( + requires_backends, + is_sagemaker_mp_enabled, + is_accelerate_available, + is_torch_xpu_available, + is_sagemaker_dp_enabled, + is_torch_tpu_available, + is_torch_npu_available) +from transformers.utils.generic import strtobool +from transformers.utils import cached_property +from transformers.training_args import logger, ParallelMode, DistributedType +import torch +import torch.distributed as dist +import os +import warnings +from datetime import timedelta + +if is_accelerate_available(): + from accelerate.state import AcceleratorState, PartialState + from accelerate.utils import DistributedType + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + + smp.init() + + +@cached_property +def _setup_devices(self) -> "torch.device": + requires_backends(self, ["torch"]) + logger.info("PyTorch: setting up devices") + if not is_sagemaker_mp_enabled(): + if not is_accelerate_available(min_version="0.20.1"): + invalidInputError( + False, + "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: " + "Please run `pip install transformers[torch]` or `pip install accelerate -U`" + ) + AcceleratorState._reset_state(reset_partial_state=True) + self.distributed_state = None + if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ: + os.environ["ACCELERATE_USE_IPEX"] = "false" + if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): + self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend) + self._n_gpu = 0 + elif is_sagemaker_mp_enabled(): + local_rank = smp.local_rank() + device = torch.device("cuda", local_rank) + self._n_gpu = 1 + torch.cuda.set_device(device) + elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ: + os.environ["ACCELERATE_USE_XPU"] = "true" + self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) + # device = torch.device("xpu:0") + device = self.distributed_state.device + self._n_gpu = 1 + elif is_sagemaker_dp_enabled(): + self.distributed_state = PartialState(_use_sagemaker_dp=True) + self._n_gpu = 1 + elif self.deepspeed: + # Need to do similar for Accelerator init + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) + del os.environ["ACCELERATE_USE_DEEPSPEED"] + self._n_gpu = 1 + else: + self.distributed_state = PartialState( + backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout) + ) + self._n_gpu = 1 + if not is_sagemaker_mp_enabled(): + device = self.distributed_state.device + self.local_rank = self.distributed_state.local_process_index + if dist.is_available() and dist.is_initialized() and \ + self.parallel_mode != ParallelMode.DISTRIBUTED: + logger.warning( + "torch.distributed process group is initialized, " + "but parallel_mode != ParallelMode.DISTRIBUTED. 
" + "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" + ) + if is_torch_tpu_available(): + device = self.distributed_state.device + self._n_gpu = 0 + elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): + # Already set _n_gpu + pass + elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU: + if "ACCELERATE_USE_XPU" not in os.environ: + os.environ["ACCELERATE_USE_XPU"] = "true" + # self._n_gpu = torch.xpu.device_count() + # device = torch.device("xpu:0") + # torch.xpu.set_device(device) + elif self.distributed_state.distributed_type == DistributedType.NO: + if self.use_mps_device: + warnings.warn( + "`use_mps_device` is deprecated and will be removed in" + " version 5.0 of 🤗 Transformers." + "`mps` device will be used by default if available similar" + " to the way `cuda` device is used." + "Therefore, no action from user is required. " + ) + if device.type != "mps": + invalidInputError(False, + ("Either you do not have an MPS-enabled device" + " on this machine or MacOS" + " version is not 12.3+ " + "or current PyTorch install was not built with MPS enabled.")) + if device.type == "mps": + self._n_gpu = 1 + elif self.use_cpu: + device = torch.device("cpu") + self._n_gpu = 0 + elif is_torch_xpu_available(): + device = torch.device("xpu:0") + torch.xpu.set_device(device) + self._n_gpu = 1 + elif is_torch_npu_available(): + device = torch.device("npu:0") + torch.npu.set_device(device) + self._n_gpu = 1 + else: + # if n_gpu is > 1 we'll use nn.DataParallel. + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` + # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will + # trigger an error that a device index is missing. Index 0 takes into account the + # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` + # will use the first GPU in that env, i.e. GPU#1 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # Sometimes the line in the postinit has not been run before we end up here, + # so just checking we're not at + # the default value. + self._n_gpu = torch.cuda.device_count() + if device.type == "cuda": + torch.cuda.set_device(device) + return device + +# remove ipex.optimize +from accelerate import Accelerator +Accelerator._prepare_ipex = patch_prepare_ipex + +# patch transformer for xpu DDP traing +from transformers import TrainingArguments +TrainingArguments._setup_devices = _setup_devices