From 54d95e4907e6aed9812e390c422d7bf9904251b8 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Wed, 8 Nov 2023 16:25:17 +0800
Subject: [PATCH] LLM: add alpaca qlora finetuning example (#9276)

---
 .../QLoRA-FineTuning/alpaca-qlora/README.md   |  50 +++
 .../alpaca-qlora/alpaca_qlora_finetuning.py   | 336 ++++++++++++++++++
 .../alpaca-qlora/templates/alpaca.json        |   7 +
 .../alpaca-qlora/templates/alpaca_legacy.json |   7 +
 .../alpaca-qlora/templates/alpaca_short.json  |   7 +
 .../alpaca-qlora/templates/vigogne.json       |   7 +
 .../alpaca-qlora/utils/prompter.py            |  81 +++++
 7 files changed, 495 insertions(+)
 create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
 create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/alpaca_qlora_finetuning.py
 create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca.json
 create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_legacy.json
 create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_short.json
 create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/vigogne.json
 create mode 100644 python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/utils/prompter.py

diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
new file mode 100644
index 00000000..de24dd8a
--- /dev/null
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/README.md
@@ -0,0 +1,50 @@
+# Alpaca QLoRA Finetuning (experimental support)
+
+This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM QLoRA on [Intel GPUs](../../README.md).
+
+### 0. Requirements
+To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../README.md#requirements) for more information.
+
+### 1. Install
+
+```bash
+conda create -n llm python=3.9
+conda activate llm
+# the command below installs intel_extension_for_pytorch==2.0.110+xpu by default;
+# you can install a specific ipex/torch version for your needs
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+pip install transformers==4.34.0
+pip install fire datasets peft==0.5.0
+pip install accelerate==0.23.0
+```
+
+### 2. Configure OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Finetune LLaMA-2-7B on a single Arc GPU
+
+Example usage:
+
+```
+python ./alpaca_qlora_finetuning.py \
+    --base_model "meta-llama/Llama-2-7b-hf" \
+    --data_path "yahma/alpaca-cleaned" \
+    --output_dir "./bigdl-qlora-alpaca"
+```
+
+**Note**: You can also set `--base_model` to the local path of a Hugging Face model checkpoint folder and `--data_path` to the local path of a dataset JSON file.
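+
+If you point `--data_path` at your own local file, it should follow the same schema as `yahma/alpaca-cleaned`, i.e. a JSON list of records with `instruction`, `input` and `output` fields (these are the keys the training script reads). A minimal illustrative sample (the texts below are made up):
+
+```json
+[
+  {
+    "instruction": "Classify the sentiment of the sentence.",
+    "input": "The training converged faster than expected.",
+    "output": "Positive"
+  },
+  {
+    "instruction": "Name three primary colors.",
+    "input": "",
+    "output": "Red, yellow and blue."
+  }
+]
+```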
+
+#### Sample Output
+```log
+{'loss': 1.9231, 'learning_rate': 2.9999945367033285e-05, 'epoch': 0.0}
+{'loss': 1.8622, 'learning_rate': 2.9999781468531096e-05, 'epoch': 0.01}
+{'loss': 1.9043, 'learning_rate': 2.9999508305687345e-05, 'epoch': 0.01}
+{'loss': 1.8967, 'learning_rate': 2.999912588049185e-05, 'epoch': 0.01}
+{'loss': 1.9658, 'learning_rate': 2.9998634195730358e-05, 'epoch': 0.01}
+{'loss': 1.8386, 'learning_rate': 2.9998033254984483e-05, 'epoch': 0.02}
+{'loss': 1.809, 'learning_rate': 2.999732306263172e-05, 'epoch': 0.02}
+{'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
+  1%|█         | 8/1164 [xx:xx 0 or (
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
+    # Only overwrite environ if wandb param passed
+    if len(wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = wandb_project
+    if len(wandb_watch) > 0:
+        os.environ["WANDB_WATCH"] = wandb_watch
+    if len(wandb_log_model) > 0:
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
+
+    if saved_low_bit_model is not None:
+        # Load the low-bit optimized model if a saved path is provided
+        model = AutoModelForCausalLM.load_low_bit(
+            saved_low_bit_model,
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            modules_to_not_convert=["lm_head"],
+        )
+    else:
+        # Load the base model from a directory or the HF Hub to 4-bit NormalFloat format
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            load_in_low_bit="nf4",  # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
+            optimize_model=False,
+            torch_dtype=torch.bfloat16,
+            # device_map=device_map,
+            modules_to_not_convert=["lm_head"],
+        )
+    print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
+    model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
+    print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
+
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+    print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
+
+    tokenizer.pad_token_id = (
+        0  # unk. we want this to be different from the eos token
+    )
+    tokenizer.padding_side = "left"  # Allow batched inference
+
+    print(model)
+
+    def tokenize(prompt, add_eos_token=True):
+        # there's probably a way to do this with the tokenizer settings
+        # but again, gotta move fast
+        result = tokenizer(
+            prompt,
+            truncation=True,
+            max_length=cutoff_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != tokenizer.eos_token_id
+            and len(result["input_ids"]) < cutoff_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        result["labels"] = result["input_ids"].copy()
+
+        return result
+
+    def generate_and_tokenize_prompt(data_point):
+        full_prompt = prompter.generate_prompt(
+            data_point["instruction"],
+            data_point["input"],
+            data_point["output"],
+        )
+        tokenized_full_prompt = tokenize(full_prompt)
+        if not train_on_inputs:
+            user_prompt = prompter.generate_prompt(
+                data_point["instruction"], data_point["input"]
+            )
+            tokenized_user_prompt = tokenize(
+                user_prompt, add_eos_token=add_eos_token
+            )
+            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+
+            if add_eos_token:
+                user_prompt_len -= 1
+
+            tokenized_full_prompt["labels"] = [
+                -100
+            ] * user_prompt_len + tokenized_full_prompt["labels"][
+                user_prompt_len:
+            ]  # could be sped up, probably
+        return tokenized_full_prompt
+
+    # Prepare a BigDL-LLM compatible Peft model
+    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=gradient_checkpointing)
+
+    config = LoraConfig(
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        target_modules=lora_target_modules,
+        lora_dropout=lora_dropout,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(model, config)
+
+    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
+        data = load_dataset("json", data_files=data_path)
+    else:
+        data = load_dataset(data_path)
+
+    if resume_from_checkpoint:
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+            resume_from_checkpoint = (
+                False  # So the trainer won't try loading its state
+            )
+        # The two files above have a different name depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            set_peft_model_state_dict(model, adapters_weights)
+        else:
+            print(f"Checkpoint {checkpoint_name} not found")
+
+    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
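+    # The block below holds out a validation split when val_set_size > 0 (with a fixed
+    # seed so the split is reproducible) and turns every record into a tokenized prompt
+    # through generate_and_tokenize_prompt via the datasets .map() call.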
+
+    if val_set_size > 0:
+        train_val = data["train"].train_test_split(
+            test_size=val_set_size, shuffle=True, seed=42
+        )
+        train_data = (
+            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+        )
+        val_data = (
+            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        )
+    else:
+        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
+        val_data = None
+
+    # Unused
+    # if not ddp and torch.cuda.device_count() > 1:
+    #     # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
+    #     model.is_parallelizable = True
+    #     model.model_parallel = True
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=transformers.TrainingArguments(
+            per_device_train_batch_size=micro_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            # warmup_ratio=0.03,
+            # warmup_steps=100,
+            max_grad_norm=0.3,
+            num_train_epochs=num_epochs,
+            learning_rate=learning_rate,
+            lr_scheduler_type="cosine",
+            bf16=True,  # makes training more stable
+            logging_steps=1,
+            optim="adamw_torch",
+            evaluation_strategy="steps" if val_set_size > 0 else "no",
+            save_strategy="steps",
+            eval_steps=100 if val_set_size > 0 else None,
+            save_steps=100,
+            output_dir=output_dir,
+            save_total_limit=100,
+            load_best_model_at_end=True if val_set_size > 0 else False,
+            ddp_find_unused_parameters=False if ddp else None,
+            group_by_length=group_by_length,
+            report_to="wandb" if use_wandb else None,
+            run_name=wandb_run_name if use_wandb else None,
+            gradient_checkpointing=gradient_checkpointing,
+            ddp_backend="ccl",
+            deepspeed=deepspeed,
+        ),
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+    )
+    model.config.use_cache = False
+
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    model.save_pretrained(output_dir)
+
+    print(
+        "\n If there's a warning about missing keys above, please disregard :)"
+    )
+
+
+if __name__ == "__main__":
+    fire.Fire(train)
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca.json b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca.json
new file mode 100644
index 00000000..3f4ae351
--- /dev/null
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca.json
@@ -0,0 +1,7 @@
+{
+    "//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca.json",
+    "description": "Template used by Alpaca-LoRA.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"
+}
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_legacy.json b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_legacy.json
new file mode 100644
index 00000000..6414fdee
--- /dev/null
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_legacy.json
@@ -0,0 +1,7 @@
+{
+    "//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca_legacy.json",
+    "description": "Legacy template, used by Original Alpaca repository.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
+    "response_split": "### Response:"
+}
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_short.json b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_short.json
new file mode 100644
index 00000000..85ac49ef
--- /dev/null
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/alpaca_short.json
@@ -0,0 +1,7 @@
+{
+    "//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/alpaca_short.json",
+    "description": "A shorter template to experiment with.",
+    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"
+}
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/vigogne.json b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/vigogne.json
new file mode 100644
index 00000000..4ca63fcc
--- /dev/null
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/templates/vigogne.json
@@ -0,0 +1,7 @@
+{
+    "//": "This file is copied from https://github.com/tloen/alpaca-lora/blob/main/templates/vigogne.json",
+    "description": "French template, used by Vigogne for finetuning.",
+    "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
+    "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
+    "response_split": "### Réponse:"
+}
diff --git a/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/utils/prompter.py b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/utils/prompter.py
new file mode 100644
index 00000000..33355129
--- /dev/null
+++ b/python/llm/example/GPU/QLoRA-FineTuning/alpaca-qlora/utils/prompter.py
@@ -0,0 +1,81 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
+#
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import os.path as osp
+from typing import Union
+from bigdl.llm.utils.common import invalidInputError
+
+
+class Prompter(object):
+    __slots__ = ("template", "_verbose")
+
+    def __init__(self, template_name: str = "", verbose: bool = False):
+        self._verbose = verbose
+        if not template_name:
+            # Enforce the default here, so the constructor can be called with '' and will not break.
+            template_name = "alpaca"
+        file_name = osp.join("templates", f"{template_name}.json")
+        if not osp.exists(file_name):
+            invalidInputError(False, f"Can't read {file_name}")
+        with open(file_name) as fp:
+            self.template = json.load(fp)
+        if self._verbose:
+            print(
+                f"Using prompt template {template_name}: {self.template['description']}"
+            )
+
+    def generate_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str]=None,
+        label: Union[None, str]=None,
+    ) -> str:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = self.template["prompt_input"].format(
+                instruction=instruction, input=input
+            )
+        else:
+            res = self.template["prompt_no_input"].format(
+                instruction=instruction
+            )
+        if label:
+            res = f"{res}{label}"
+        if self._verbose:
+            print(res)
+        return res
+
+    def get_response(self, output: str) -> str:
+        return output.split(self.template["response_split"])[1].strip()
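+
+
+# Illustrative usage sketch (not part of the training flow; the strings below are
+# made up):
+#
+#     prompter = Prompter("alpaca")
+#     prompt = prompter.generate_prompt(
+#         instruction="Name three primary colors.",
+#         label="Red, yellow and blue.",
+#     )
+#     # At inference time, get_response() returns only the text after the
+#     # "### Response:" marker of the model's raw output:
+#     # answer = prompter.get_response(model_output)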