diff --git a/python/llm/example/gpu/qlora_finetuning/README.md b/python/llm/example/gpu/qlora_finetuning/README.md
new file mode 100644
index 00000000..7e14656c
--- /dev/null
+++ b/python/llm/example/gpu/qlora_finetuning/README.md
@@ -0,0 +1,50 @@
+# QLoRA (experimental support)
+
+This example demonstrates how to finetune a llama2-7b model using BigDL-LLM 4bit optimizations on [Intel GPUs](../README.md).
+
+## 0. Requirements
+To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Finetune llama2-7b using QLoRA
+
+This example is ported from [bnb-4bit-training](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing).
+
+### 1. Install
+
+```bash
+conda create -n llm python=3.9
+conda activate llm
+# the below command will install intel_extension_for_pytorch==2.0.110+xpu by default
+# you can install a specific ipex/torch version for your need
+pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
+pip install git+https://github.com/huggingface/transformers.git@95fe0f5
+pip install peft==0.5.0
+```
+
+### 2. Configure OneAPI environment variables
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Run
+
+```bash
+python ./qlora_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH
+```
+
+### Sample Output
+```log
+{'loss': 1.6134, 'learning_rate': 0.0002, 'epoch': 0.03}
+{'loss': 1.3038, 'learning_rate': 0.00017777777777777779, 'epoch': 0.06}
+{'loss': 1.2634, 'learning_rate': 0.00015555555555555556, 'epoch': 0.1}
+{'loss': 1.2389, 'learning_rate': 0.00013333333333333334, 'epoch': 0.13}
+{'loss': 1.0399, 'learning_rate': 0.00011111111111111112, 'epoch': 0.16}
+{'loss': 1.0406, 'learning_rate': 8.888888888888889e-05, 'epoch': 0.19}
+{'loss': 1.3114, 'learning_rate': 6.666666666666667e-05, 'epoch': 0.22}
+{'loss': 0.9876, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.26}
+{'loss': 1.1406, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.29}
+{'loss': 1.1728, 'learning_rate': 0.0, 'epoch': 0.32}
+{'train_runtime': 225.8005, 'train_samples_per_second': 3.543, 'train_steps_per_second': 0.886, 'train_loss': 1.211241865158081, 'epoch': 0.32}
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [03:45<00:00, 1.13s/it]
+TrainOutput(global_step=200, training_loss=1.211241865158081, metrics={'train_runtime': 225.8005, 'train_samples_per_second': 3.543, 'train_steps_per_second': 0.886, 'train_loss': 1.211241865158081, 'epoch': 0.32})
+```
\ No newline at end of file
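Before running step 3 above, it can help to confirm that PyTorch actually sees the Intel GPU. The check below is not part of the example; it is a minimal sketch that assumes the `bigdl-llm[xpu]` install succeeded and uses the standard IPEX `torch.xpu` API.

```python
# Sanity check: is an XPU device visible to PyTorch?
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (importing enables torch.xpu)

print(torch.xpu.is_available())      # expected: True
print(torch.xpu.get_device_name(0))  # name of the detected Intel GPU
```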
diff --git a/python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py b/python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py
new file mode 100644
index 00000000..6531b483
--- /dev/null
+++ b/python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py
@@ -0,0 +1,84 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import os
+os.environ["ACCELERATE_USE_IPEX"] = "true"
+os.environ["ACCELERATE_USE_XPU"] = "true"
+
+import transformers
+from transformers import LlamaTokenizer
+
+from peft import LoraConfig
+import intel_extension_for_pytorch as ipex
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers import AutoModelForCausalLM
+from datasets import load_dataset
+import argparse
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='Finetune a Llama2 model using QLoRA')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
+                        help='The huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-hf` or `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    parser.add_argument('--dataset', type=str, default="Abirate/english_quotes")
+
+    args = parser.parse_args()
+    model_path = args.repo_id_or_model_path
+    dataset_path = args.dataset
+    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    data = load_dataset(dataset_path)
+    data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+    model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                 load_in_4bit=True,
+                                                 optimize_model=False,
+                                                 modules_to_not_convert=["lm_head"])
+    model = model.to('xpu')
+    model.gradient_checkpointing_enable()
+    model = prepare_model_for_kbit_training(model)
+    config = LoraConfig(
+        r=8,
+        lora_alpha=32,
+        target_modules=["q_proj", "k_proj", "v_proj"],
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM"
+    )
+    model = get_peft_model(model, config)
+    tokenizer.pad_token_id = 0
+    tokenizer.padding_side = "left"
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=data["train"],
+        args=transformers.TrainingArguments(
+            per_device_train_batch_size=4,
+            gradient_accumulation_steps=1,
+            warmup_steps=20,
+            max_steps=200,
+            learning_rate=2e-4,
+            fp16=False,  # fp16 is not supported yet
+            logging_steps=20,
+            output_dir="outputs",
+            optim="adamw_hf",  # paged_adamw_8bit is not supported yet
+            # gradient_checkpointing=True,  # can further reduce memory but is slower
+        ),
+        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
+    )
+    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
+    result = trainer.train()
+    print(result)
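Note that the script above prints the `TrainOutput` but never writes the trained adapter to disk. If you want to keep it, a small hedged addition at the end of the script (the output directory name is arbitrary, not something the example defines) could be:

```python
# Persist only the LoRA adapter weights (a few MB), not the 4-bit base model.
model.save_pretrained("outputs/lora_adapter")      # hypothetical path
tokenizer.save_pretrained("outputs/lora_adapter")
```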
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index f0e20881..3e4f9f60
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -99,8 +99,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
-                         convert_shape_only=False, device="cpu"):
-    modules_to_not_convert = []  # ["lm_head"]
+                         convert_shape_only=False, device="cpu",
+                         modules_to_not_convert=None):
+    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert, None,
         convert_shape_only,
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 69bab76e..89b2e7fb
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -284,8 +284,38 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
     return result_t
 
 
+class MatMulLowBit(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, A, weight):
+        ctx.is_empty = False
+        import linear_q4_0
+        result = linear_q4_0.forward_new(A, weight.data, weight.qtype)
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (A, weight)
+        else:
+            ctx.tensors = (None, None)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        import linear_q4_0
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+        req_gradA, _ = ctx.needs_input_grad
+        A, weight = ctx.tensors
+        grad_A, grad_weight = None, None
+        if req_gradA:
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
+            grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
+
+        return grad_A, grad_weight
+
+
 class LowBitLinear(nn.Linear):
-    def __init__(self, input_features, output_features, qtype, bias=True):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 conver_to_half=True):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
@@ -295,6 +325,7 @@ class LowBitLinear(nn.Linear):
         self.weight_shape = (self.out_len, self.in_len)
         self.weight_length = self.out_len * self.in_len
         self.qtype = qtype
+        self.conver_to_half = conver_to_half
 
     def forward(self, x: torch.Tensor):
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -317,10 +348,14 @@ class LowBitLinear(nn.Linear):
             if x_2d.is_contiguous() is False:
                 x_2d = x_2d.contiguous()
             # current workaround to reduce first token latency of fp32 input
-            if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
+            # sometimes fp16 causes NaN and training instability
+            # disable the conversion when training
+            if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                 x_2d = x_2d.half()
-            # input format of linear_q4.forward is 1: input, 2: weight
-            result = linear_q4_0.forward_new(x_2d, x0, self.qtype)
+            if self.training and x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight)
+            else:
+                result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.bias is not None:
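For readers unfamiliar with the `MatMulLowBit` pattern introduced above: because the quantized weight stays frozen, the backward pass only needs the gradient with respect to the activations, computed against a dequantized copy of the weight. Below is a simplified fp32 sketch of the same structure, with a plain `torch.matmul` standing in for the `linear_q4_0` kernels; it is illustrative only and not part of the patch.

```python
import torch

class FrozenWeightMatMul(torch.autograd.Function):
    """Toy analogue of MatMulLowBit: the weight is frozen, so only grad_A is produced."""

    @staticmethod
    def forward(ctx, A, weight):
        # The real code calls linear_q4_0.forward_new on 4-bit weight data here.
        ctx.save_for_backward(weight)
        return torch.matmul(A, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # dL/dA = dL/dY @ W; the real code dequantizes the 4-bit weight first.
        grad_A = torch.matmul(grad_output, weight)
        return grad_A, None  # no gradient for the frozen weight

A = torch.randn(4, 16, requires_grad=True)
W = torch.randn(8, 16)  # (out_features, in_features), kept frozen
y = FrozenWeightMatMul.apply(A, W)
y.sum().backward()
print(A.grad.shape)  # torch.Size([4, 16])
```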
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index acbbceed..d0c4aa0f
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -108,6 +108,7 @@ class _BaseAutoModelClass:
         # In case it needs a second try,
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
+        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         try:
@@ -119,7 +120,8 @@ class _BaseAutoModelClass:
             model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
             model.config.update({"bigdl_lcmu_enabled": False})
             model = model.to("cpu")
-            model = ggml_convert_low_bit(model, qtype, optimize_model)
+            model = ggml_convert_low_bit(model, qtype, optimize_model,
+                                         modules_to_not_convert=modules_to_not_convert)
             model.config.update({"bigdl_transformers_low_bit": q_k})
             model.config.update({"tie_word_embeddings": False})
 
@@ -155,6 +157,7 @@ class _BaseAutoModelClass:
         import copy
         import os
 
+        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
         kwargs_orig = copy.deepcopy(kwargs)
@@ -264,7 +267,8 @@ class _BaseAutoModelClass:
 
         # Loading args may differ based on their usage
        quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
-        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device)
+        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
+                                     modules_to_not_convert=modules_to_not_convert)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
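The `modules_to_not_convert` keyword threaded through `from_pretrained` above is what lets the example script keep `lm_head` in full precision while the attention projections are quantized. A quick way to confirm which layers were converted, assuming a standard `LlamaForCausalLM` module layout (this snippet is not part of the patch):

```python
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_4bit=True,
                                             optimize_model=False,
                                             modules_to_not_convert=["lm_head"])
print(type(model.lm_head))                           # stays a regular torch.nn.Linear
print(type(model.model.layers[0].self_attn.q_proj))  # converted to LowBitLinear
```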
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
new file mode 100644
index 00000000..d2728f08
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -0,0 +1,191 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/tuners/lora.py
+#
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from bigdl.llm.transformers.low_bit_linear import LowBitLinear
+from peft.tuners.lora import LoraLayer
+from bigdl.llm.utils.common import invalidInputError
+
+
+class LoraLowBitLinear(LowBitLinear, LoraLayer):
+    # Lora implemented in a dense layer
+    def __init__(
+        self,
+        adapter_name,
+        in_features,
+        out_features,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        LowBitLinear.__init__(
+            self,
+            in_features,
+            out_features,
+            qtype=kwargs.get("qtype"),
+            bias=kwargs.get("bias", True),
+            conver_to_half=False,
+        )
+        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+
+        # Freezing the pre-trained weight matrix
+        self.weight.requires_grad = False
+
+        init_lora_weights = kwargs.pop("init_lora_weights", True)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
+        self.active_adapter = adapter_name
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+
+        if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
+            return result
+        elif self.r[self.active_adapter] > 0:
+            result = result.clone()
+            if not torch.is_autocast_enabled():
+                expected_dtype = result.dtype
+                x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+                output = (
+                    self.lora_B[self.active_adapter](
+                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    ).to(expected_dtype)
+                    * self.scaling[self.active_adapter]
+                )
+            else:
+                output = (
+                    self.lora_B[self.active_adapter](
+                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    )
+                    * self.scaling[self.active_adapter]
+                )
+            result += output
+        return result
+
+
+@staticmethod
+def _create_new_module(lora_config, adapter_name, target, **kwargs):
+
+    bias = kwargs.pop("bias", False)
+
+    if isinstance(target, LowBitLinear):
+        low_bit_kwargs = kwargs.copy()
+        low_bit_kwargs.update(
+            {
+                "qtype": target.qtype,
+            }
+        )
+        new_module = LoraLowBitLinear(adapter_name,
+                                      target.in_features,
+                                      target.out_features,
+                                      bias=bias,
+                                      **low_bit_kwargs)
+    else:
+        invalidInputError(False,
+                          f"Target module {target} is not supported. "
+                          f"Currently, only `LowBitLinear` is supported.")
+
+    return new_module
+
+
+from peft.tuners.lora import LoraModel
+
+
+def get_peft_model(*args, **kwargs):
+    old_create_new_module = LoraModel._create_new_module
+    LoraModel._create_new_module = _create_new_module
+    try:
+        from peft import get_peft_model as get_peft_model_original
+        model = get_peft_model_original(*args, **kwargs)
+    finally:
+        LoraModel._create_new_module = old_create_new_module
+
+    return model
+
+
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
+    r"""
+    This method wraps the entire protocol for preparing a model before running training.
+    This includes:
+        1- Cast the layernorm in fp32
+        2- Make the output embedding layer require grads
+        3- Add the upcasting of the lm head to fp32
+
+    Args:
+        model (`transformers.PreTrainedModel`):
+            The loaded model from `transformers`
+    """
+
+    is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
+    for name, param in model.named_parameters():
+        # freeze base model's layers
+        param.requires_grad = False
+
+    if not is_gptq_quantized:
+        # cast all non INT8 parameters to fp32
+        for param in model.parameters():
+            if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+                param.data = param.data.to(torch.float32)
+
+    if use_gradient_checkpointing:
+        # For backward compatibility
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        # enable gradient checkpointing for memory efficiency
+        model.gradient_checkpointing_enable()
+
+    return model
+
+
+class PeftModel:
+
+    @staticmethod
+    def from_pretrained(*args,
+                        **kwargs):
+        old_create_new_module = LoraModel._create_new_module
+        LoraModel._create_new_module = _create_new_module
+        from peft import PeftModel
+        try:
+            model = PeftModel.from_pretrained(*args, **kwargs)
+        finally:
+            LoraModel._create_new_module = old_create_new_module
+
+        return model
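The example finetuning script stops at `trainer.train()`. If you later want to run generation with the trained adapter, a rough sketch using the `PeftModel` wrapper defined above could look like the following; the adapter path and prompt are made up, and `use_cache` has to be switched back on for inference.

```python
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401
from transformers import LlamaTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                                  load_in_4bit=True,
                                                  optimize_model=False,
                                                  modules_to_not_convert=["lm_head"])
model = PeftModel.from_pretrained(base_model, "outputs/lora_adapter")  # hypothetical adapter path
model = model.to('xpu')
model.config.use_cache = True  # re-enable the cache disabled during training

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Imagination is more", return_tensors="pt").to('xpu')
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```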