diff --git a/python/llm/example/gpu/qlora_finetuning/README.md b/python/llm/example/gpu/qlora_finetuning/README.md
index 7e14656c..7b98b1b6 100644
--- a/python/llm/example/gpu/qlora_finetuning/README.md
+++ b/python/llm/example/gpu/qlora_finetuning/README.md
@@ -1,4 +1,4 @@
-# Q-Lora (experimental support)
+# Finetuning LLaMA Using Q-Lora (experimental support)
 
 This example demonstrates how to finetune a llama2-7b model use Big-LLM 4bit optimizations using [Intel GPUs](../README.md).
 
@@ -7,7 +7,7 @@ To run this example with BigDL-LLM on Intel GPUs, we have some recommended requi
 
 ## Example: Finetune llama2-7b using qlora
 
-This example is ported from [bnb-4bit-training](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing)
+This example is ported from [bnb-4bit-training](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing). The `export_merged_model.py` script is ported from [alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/export_hf_checkpoint.py).
 
 ### 1. Install
 
@@ -26,13 +26,13 @@ pip install peft==0.5.0
 source /opt/intel/oneapi/setvars.sh
 ```
 
-### 3. Run
+### 3. Finetune model
 
 ```
 python ./qlora_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH
 ```
 
-### Sample Output
+#### Sample Output
 ```log
 {'loss': 1.6134, 'learning_rate': 0.0002, 'epoch': 0.03}
 {'loss': 1.3038, 'learning_rate': 0.00017777777777777779, 'epoch': 0.06}
@@ -47,4 +47,12 @@ python ./qlora_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH
 {'train_runtime': 225.8005, 'train_samples_per_second': 3.543, 'train_steps_per_second': 0.886, 'train_loss': 1.211241865158081, 'epoch': 0.32}
 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [03:45<00:00, 1.13s/it]
 TrainOutput(global_step=200, training_loss=1.211241865158081, metrics={'train_runtime': 225.8005, 'train_samples_per_second': 3.543, 'train_steps_per_second': 0.886, 'train_loss': 1.211241865158081, 'epoch': 0.32})
-```
\ No newline at end of file
+```
+
+### 4. Merge the adapter into the original model
+
+```
+python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
+```
+
+Then you can use `./outputs/checkpoint-200-merged` as a normal Hugging Face Transformers model for inference.
diff --git a/python/llm/example/gpu/qlora_finetuning/export_merged_model.py b/python/llm/example/gpu/qlora_finetuning/export_merged_model.py
new file mode 100644
index 00000000..1cf3c2ff
--- /dev/null
+++ b/python/llm/example/gpu/qlora_finetuning/export_merged_model.py
@@ -0,0 +1,93 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file is adapted from https://github.com/tloen/alpaca-lora/blob/main/export_hf_checkpoint.py
+#
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+import transformers
+from transformers import LlamaTokenizer  # noqa: F402
+from bigdl.llm.transformers.qlora import PeftModel
+from bigdl.llm.transformers import AutoModelForCausalLM
+import argparse
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
+                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    parser.add_argument('--adapter_path', type=str,)
+    parser.add_argument('--output_path', type=str,)
+
+    args = parser.parse_args()
+    base_model = model_path = args.repo_id_or_model_path
+    adapter_path = args.adapter_path
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model,
+        # load_in_low_bit="nf4",  # should load the original (not low-bit) model for merging
+        torch_dtype=torch.float16,
+        device_map={"": "cpu"},
+    )
+
+    first_weight = base_model.model.layers[0].self_attn.q_proj.weight
+    first_weight_old = first_weight.clone()
+
+    lora_model = PeftModel.from_pretrained(
+        base_model,
+        adapter_path,
+        device_map={"": "cpu"},
+        torch_dtype=torch.float16,
+    )
+
+    lora_weight = lora_model.base_model.model.model.layers[
+        0
+    ].self_attn.q_proj.weight
+
+    assert torch.allclose(first_weight_old, first_weight)
+
+    # merge weights - new merging method from peft
+    lora_model = lora_model.merge_and_unload()
+
+    lora_model.train(False)
+
+    # did we do anything?
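+    # merge_and_unload() folds the LoRA deltas into the base weights in place,
+    # so the cloned pre-merge q_proj weight should no longer match the merged one.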
+    assert not torch.allclose(first_weight_old, first_weight)
+
+    lora_model_sd = lora_model.state_dict()
+    deloreanized_sd = {
+        k.replace("base_model.model.", ""): v
+        for k, v in lora_model_sd.items()
+        if "lora" not in k
+    }
+
+    base_model.save_pretrained(args.output_path, state_dict=deloreanized_sd)
+    tokenizer.save_pretrained(args.output_path)
diff --git a/python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py b/python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py
index 6531b483..85b5642e 100644
--- a/python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py
+++ b/python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py
@@ -45,8 +45,9 @@ if __name__ == "__main__":
     data = load_dataset(dataset_path)
     data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
     model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                 load_in_4bit=True,
+                                                 load_in_low_bit="nf4",
                                                  optimize_model=False,
+                                                 torch_dtype=torch.float16,
                                                  modules_to_not_convert=["lm_head"],)
     model = model.to('xpu')
     model.gradient_checkpointing_enable()
@@ -71,7 +72,8 @@ if __name__ == "__main__":
             warmup_steps=20,
             max_steps=200,
             learning_rate=2e-4,
-            fp16=False,  # fp16 is not supported yet
+            save_steps=100,
+            fp16=True,
             logging_steps=20,
             output_dir="outputs",
             optim="adamw_hf",  # paged_adamw_8bit is not supported yet
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index d2728f08..2b074105 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -36,6 +36,7 @@ import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
+import functools
 
 
 class LoraLowBitLinear(LowBitLinear, LoraLayer):
@@ -94,13 +95,11 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         return result
 
 
-@staticmethod
-def _create_new_module(lora_config, adapter_name, target, **kwargs):
-
-    bias = kwargs.pop("bias", False)
+def _create_new_module(create_new_module_func, lora_config, adapter_name, target, **kwargs):
 
     if isinstance(target, LowBitLinear):
         low_bit_kwargs = kwargs.copy()
+        bias = low_bit_kwargs.pop("bias", False)
         low_bit_kwargs.update(
             {
                 "qtype": target.qtype,
@@ -112,9 +111,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                                       bias=bias,
                                       **low_bit_kwargs)
     else:
-        invalidInputError(False,
-                          f"Target module {target} is not supported. "
-                          f"Currently, only `LowBitLinear` are supported.")
+        new_module = create_new_module_func(lora_config, adapter_name, target, **kwargs)
 
     return new_module
 
@@ -124,7 +121,8 @@ from peft.tuners.lora import LoraModel
 
 def get_peft_model(*args, **kwargs):
     old_create_new_module = LoraModel._create_new_module
-    LoraModel._create_new_module = _create_new_module
+    LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
+                                                                  old_create_new_module))
     try:
         from peft import get_peft_model as get_peft_model_original
         model = get_peft_model_original(*args, **kwargs)
@@ -181,7 +179,8 @@ class PeftModel:
 
     def from_pretrained(*args, **kwargs):
         old_create_new_module = LoraModel._create_new_module
-        LoraModel._create_new_module = _create_new_module
+        LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
+                                                                      old_create_new_module))
         from peft import PeftModel
         try:
             model = PeftModel.from_pretrained(*args, **kwargs)
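Note on step 4 of the README above: the merged checkpoint saved by `export_merged_model.py` can be loaded with plain `transformers`, like any other Llama2 checkpoint. Below is a minimal inference sketch; the checkpoint path follows the README's example output directory, while the prompt and generation settings are illustrative assumptions rather than part of this patch.

```python
# Minimal sketch: load the merged checkpoint from step 4 with plain transformers
# and run a short generation. Prompt and max_new_tokens are illustrative choices.
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

merged_path = "./outputs/checkpoint-200-merged"  # output of export_merged_model.py

tokenizer = LlamaTokenizer.from_pretrained(merged_path)
model = AutoModelForCausalLM.from_pretrained(merged_path, torch_dtype=torch.float16)

inputs = tokenizer("A quote to remember: ", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```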