diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
index d3cce346..29b81e82 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
@@ -101,6 +101,12 @@ bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh
 
 #### LoRA
 
+##### Finetuning LLaMA2-7B on single Arc A770
+
+```bash
+bash lora_finetune_llama2_7b_arc_1_card.sh
+```
+
 ##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
 
 ```bash
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_arc_1_card.sh b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_arc_1_card.sh
new file mode 100644
index 00000000..dcb4a82d
--- /dev/null
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/lora_finetune_llama2_7b_arc_1_card.sh
@@ -0,0 +1,26 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
+python ./alpaca_qlora_finetuning.py \
+    --micro_batch_size 8 \
+    --batch_size 128 \
+    --base_model "meta-llama/Llama-2-7b-hf" \
+    --data_path "yahma/alpaca-cleaned" \
+    --output_dir "./bigdl-lora-alpaca" \
+    --gradient_checkpointing True \
+    --lora_target_modules "['k_proj', 'q_proj', 'o_proj', 'v_proj']" \
+    --training_mode "lora"
\ No newline at end of file
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py b/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py
index 06792386..b7474ab8 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/export_merged_model.py
@@ -54,7 +54,8 @@ if __name__ == "__main__":
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
 
     lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json"))
-    qa_lora = lora_config.get("qa_lora", False)
+    training_mode = lora_config.get("training_mode", "qlora")
+    qa_lora = training_mode == "qalora"
 
     temp_dir = None
     if qa_lora:
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index ca0b9f86..2646b074 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -262,6 +262,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                             .to(device_type)
                 elif qtype == ggml_tensor_qtype["bf16"]:
+                    module.to(torch.bfloat16)
                     new_linear = BF16Linear(
                         in_features,
                         out_features,
@@ -344,7 +345,7 @@ def _optimize_pre(model):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None, cpu_embedding=False,
-                         lightweight_bmm=False):
+                         lightweight_bmm=False, torch_dtype="auto"):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -366,7 +367,10 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         )
     elif device == "cpu":
         if not (getattr(model, "quantization_method", None) == "gptq"):
-            model.to(torch.float32)
+            if torch_dtype == "auto":
+                convert_bigdl_other_module(model, torch.float32)
+            else:
+                convert_bigdl_other_module(model, torch_dtype)
     elif device == "meta":
         # Do nothing here for weights are empty.
         pass
@@ -376,6 +380,17 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     return model
 
 
+def convert_bigdl_other_module(model, dtype):
+    # Convert modules outside of bigdl linear to corresponding dtype
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, \
+        FP16Linear, BF16Linear
+    for module in model.modules():
+        if list(module.children()) == []:
+            # leaf module
+            if not isinstance(module, (LowBitLinear, FP16Linear, BF16Linear)):
+                module.to(dtype)
+
+
 def convert_forward(m, target_m, new_forward):
     for _, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index a48e0f08..8337d345 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -599,27 +599,18 @@ class BF16Linear(nn.Linear):
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
         self.weight_length = self.out_len * self.in_len
+        self.qtype = ggml_tensor_qtype["bf16"]
         self.mp_group = mp_group
         self.compute_dtype = compute_dtype
 
     def forward(self, x: torch.Tensor):
-        # only work for GPU now
-        invalidInputError(x.device.type == "xpu",
-                          "bf16 only works for GPU now")
-        is_training = self.training and not torch.is_inference_mode_enabled()
-        if is_training:
-            # below logic is only for training
-            autocast_dtype = get_autocast_dtype(x)
-            if self.compute_dtype is not None and x.device.type == "xpu":
-                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
-            elif autocast_dtype is not None:
-                x = x.to(autocast_dtype)
-
+        x = x.to(torch.bfloat16)
+        if self.weight is not None and self.weight.dtype != x.dtype:
+            self.weight.data = self.weight.data.to(x.dtype)
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
         result = F.linear(x, self.weight)
         if self.bias is not None:
             result += self.bias
-
         return result.to(x.dtype)
 
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 8c6930dc..bb5ce960 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -304,7 +304,8 @@ class _BaseAutoModelClass:
             model = model.to("cpu")
         model = ggml_convert_low_bit(model, qtype, optimize_model,
                                      modules_to_not_convert=modules_to_not_convert,
-                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
+                                     cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                     torch_dtype=kwargs.get("torch_dtype", 'auto'))
         model.config.update({"bigdl_transformers_low_bit": q_k})
         model.config.update({"tie_word_embeddings": False})
 
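The `torch_dtype` plumbing in convert.py and model.py above is easiest to see from the loading side. Below is a minimal sketch of the intended call path, assuming the public `bigdl.llm.transformers` API; the checkpoint id and the chosen dtype are illustrative only and are not part of this patch.

```python
import torch
from bigdl.llm.transformers import AutoModelForCausalLM

# torch_dtype is forwarded from from_pretrained -> ggml_convert_low_bit ->
# convert_bigdl_other_module, so leaf modules outside the BigDL linear layers
# (embeddings, norms, etc.) follow the requested dtype instead of always
# being cast to float32.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",    # illustrative checkpoint
    load_in_low_bit="bf16",        # BF16Linear path touched by this patch
    optimize_model=False,
    torch_dtype=torch.bfloat16,    # default "auto" keeps the float32 behaviour
)
```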
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 55bb0f26..56ccc211 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -519,6 +519,9 @@ def check_flash_attention_available(query):
         # ipex flash attention is only supported for xetla
         # may update this later
         return False
+    if query.dtype not in [torch.float32, torch.float16]:
+        # only use flash attention for fp32/fp16 input
+        return False
     return True
 
 
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index c4ab77bc..739529ff 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -49,10 +49,12 @@
 # limitations under the License.
 
 import torch
+from torch.nn import Linear, Embedding
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.utils import get_autocast_dtype
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 import functools
 from bigdl.llm.transformers import training_patch
 
@@ -275,9 +277,19 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
 
     if not is_gptq_quantized:
         # cast all non INT8 parameters to fp32
-        for param in model.parameters():
-            if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
-                param.data = param.data.to(torch.float32)
+        # for param in model.parameters():
+        #     if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+        #         param.data = param.data.to(torch.float32)
+
+        # change to below way to reduce memory for Linear
+        # otherwise lora finetuning on arc may OOM at this convert
+        for module in model.modules():
+            if list(module.children()) == []:
+                # leaf module
+                if not isinstance(module, (Linear, Embedding)):
+                    for param in module.parameters():
+                        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+                            param.data = param.data.to(torch.float32)
 
     if use_gradient_checkpointing:
         # For backward compatibility
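For reference, the leaf-module traversal used both in `convert_bigdl_other_module` and in the rewritten cast loop of `prepare_model_for_kbit_training` follows a single pattern. The sketch below restates it as a standalone function with a hypothetical helper name that is not part of the patch.

```python
import torch
from torch import nn


def cast_non_linear_leaves(model: nn.Module, dtype: torch.dtype = torch.float32) -> None:
    """Hypothetical helper illustrating the pattern used in this patch:
    upcast only leaf modules that are NOT Linear/Embedding, so large
    (low-bit or bf16) linear weights stay untouched and LoRA finetuning
    on a single Arc A770 does not OOM during model preparation."""
    for module in model.modules():
        if list(module.children()) == []:  # leaf module
            if not isinstance(module, (nn.Linear, nn.Embedding)):
                for param in module.parameters():
                    if param.dtype in (torch.float16, torch.bfloat16):
                        param.data = param.data.to(dtype)
```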