From d7ca5d935b4b827946712bff208464fd1cc57c69 Mon Sep 17 00:00:00 2001
From: Qiyuan Gong
Date: Tue, 7 May 2024 15:09:14 +0800
Subject: [PATCH] Upgrade Peft version to 0.10.0 for LLM finetune (#10886)

* Upgrade Peft version to 0.10.0
* Upgrade Peft version in ARC unit test and HF-Peft example.
---
 .github/workflows/llm_unit_tests.yml          |   2 +-
 .../GPU/LLM-Finetuning/HF-PEFT/README.md      |   4 +-
 .../QLoRA/alpaca-qlora/README.md              |   4 +-
 python/llm/src/ipex_llm/llm_patching.py       |   6 +-
 python/llm/src/ipex_llm/transformers/qlora.py | 179 ++++++++++--------
 5 files changed, 105 insertions(+), 90 deletions(-)

diff --git a/.github/workflows/llm_unit_tests.yml b/.github/workflows/llm_unit_tests.yml
index 5eb5b55e..48073352 100644
--- a/.github/workflows/llm_unit_tests.yml
+++ b/.github/workflows/llm_unit_tests.yml
@@ -370,7 +370,7 @@ jobs:
         shell: bash
         run: |
           python -m pip uninstall datasets -y
-          python -m pip install transformers==4.34.0 datasets peft==0.5.0 accelerate==0.23.0
+          python -m pip install transformers==4.36.0 datasets peft==0.10.0 accelerate==0.23.0
           python -m pip install bitsandbytes scipy
           # Specific oneapi position on arc ut test machines
           if [[ '${{ matrix.pytorch-version }}' == '2.1' ]]; then
diff --git a/python/llm/example/GPU/LLM-Finetuning/HF-PEFT/README.md b/python/llm/example/GPU/LLM-Finetuning/HF-PEFT/README.md
index 7da65981..fb2a6288 100644
--- a/python/llm/example/GPU/LLM-Finetuning/HF-PEFT/README.md
+++ b/python/llm/example/GPU/LLM-Finetuning/HF-PEFT/README.md
@@ -14,8 +14,8 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.34.0 datasets
-pip install fire peft==0.5.0
+pip install transformers==4.36.0 datasets
+pip install fire peft==0.10.0
 pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ # necessary to run distributed finetuning
 pip install accelerate==0.23.0
 pip install bitsandbytes scipy
diff --git a/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/README.md b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/README.md
index 4cdb3d26..10cdd1c6 100644
--- a/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/README.md
+++ b/python/llm/example/GPU/LLM-Finetuning/QLoRA/alpaca-qlora/README.md
@@ -14,8 +14,8 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.34.0 datasets
-pip install fire peft==0.5.0
+pip install transformers==4.36.0 datasets
+pip install fire peft==0.10.0
 pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ # necessary to run distributed finetuning
 pip install accelerate==0.23.0
 pip install bitsandbytes scipy
diff --git a/python/llm/src/ipex_llm/llm_patching.py b/python/llm/src/ipex_llm/llm_patching.py
index d68fac0c..aaf4fae4 100644
--- a/python/llm/src/ipex_llm/llm_patching.py
+++ b/python/llm/src/ipex_llm/llm_patching.py
@@ -17,6 +17,8 @@
 import transformers
 import importlib
 import sys
+from packaging import version
+
 from ipex_llm.utils.common import invalidInputError
 from enum import Enum

@@ -57,12 +59,14 @@ def llm_patch(train=False):
         import peft
         from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
             LoraConfig, TrainingArguments
+        peft_version = peft.__version__
         replace_attr(transformers, "TrainingArguments", TrainingArguments)
         get_peft_model_original = getattr(peft, "get_peft_model")
         replace_attr(peft, "get_peft_model", get_peft_model)
         setattr(peft, "get_peft_model_original", get_peft_model_original)
         replace_attr(peft, "prepare_model_for_kbit_training", prepare_model_for_kbit_training)
-        replace_attr(peft, "prepare_model_for_int8_training", prepare_model_for_kbit_training)
+        if version.parse(peft_version) <= version.parse("0.5.0"):
+            replace_attr(peft, "prepare_model_for_int8_training", prepare_model_for_kbit_training)
         replace_attr(peft, "LoraConfig", LoraConfig)

         bigdl_patched = 'Train'
diff --git a/python/llm/src/ipex_llm/transformers/qlora.py b/python/llm/src/ipex_llm/transformers/qlora.py
index 5c9f3b54..a9696f9c 100644
--- a/python/llm/src/ipex_llm/transformers/qlora.py
+++ b/python/llm/src/ipex_llm/transformers/qlora.py
@@ -50,9 +50,10 @@

 import torch
 import logging
-from torch.nn import Linear, Embedding
+from torch.nn import Linear, Embedding, Module
 from ipex_llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
 from peft.tuners.lora import LoraLayer
+from typing import Any, Optional, Union
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.utils import get_autocast_dtype
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
@@ -62,38 +63,46 @@ from ipex_llm.transformers import training_patch
 LOG = logging.getLogger("ipex_llm.qlora")


-class LoraLowBitLinear(LowBitLinear, LoraLayer):
+class LoraLowBitLinear(Module, LoraLayer):
     # Lora implemented in a dense layer
     def __init__(
         self,
+        base_layer,
         adapter_name,
-        in_features,
-        out_features,
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         qa_lora: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
+        is_target_conv_1d_layer: bool = False,
+        init_lora_weights: Union[bool, str]=True,
+        use_rslora: bool = False,
+        use_dora: bool = False,
         **kwargs,
     ):
-        LowBitLinear.__init__(
-            self,
-            in_features,
-            out_features,
-            qtype=kwargs.get("qtype"),
-            bias=kwargs.get("bias", True),
-            conver_to_half=False,
-        )
-
+        super().__init__()
         qk_size = get_qk_size(kwargs.get("qtype"))
-        lora_in_features = in_features // qk_size if qa_lora else in_features
-        LoraLayer.__init__(self, in_features=lora_in_features, out_features=out_features)
+        if qa_lora:
+            # qa_lora need to change the in_features of the base_layer
+            in_features = base_layer.in_features
+            base_layer.in_features = in_features // qk_size

-        # Freezing the pre-trained weight matrix
-        self.weight.requires_grad = False
+        LoraLayer.__init__(self, base_layer, **kwargs)
+
+        self.fan_in_fan_out = fan_in_fan_out
+        self._active_adapter = adapter_name
+        self.update_layer(
+            adapter_name,
+            r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            init_lora_weights=init_lora_weights,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
+        )
+        self.is_target_conv_1d_layer = is_target_conv_1d_layer

-        init_lora_weights = kwargs.pop("init_lora_weights", True)
-        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
         if qa_lora:
             self.qa_pool = torch.nn.AvgPool1d(qk_size)
         else:
@@ -106,62 +115,66 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             x = x.to(torch.bfloat16)
         elif autocast_dtype is not None:
             x = x.to(autocast_dtype)
-        result = super().forward(x)
+        result = self.base_layer.forward(x)

-        if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
+        if self.disable_adapters or self.merged:
             return result
-        elif self.r[self.active_adapter] > 0:
-            result = result.clone()
+        else:
             if autocast_dtype is None and x.device.type == "cpu":
                 expected_dtype = result.dtype
-                x = x.to(self.lora_A[self.active_adapter].weight.dtype)
-                output = (
-                    self.lora_B[self.active_adapter](
-                        self.lora_A[self.active_adapter](
-                            self.lora_dropout[self.active_adapter](self.qa_pool(x)))
-                    ).to(expected_dtype)
-                    * self.scaling[self.active_adapter]
-                )
+                for active_adapter in self.active_adapters:
+                    if active_adapter not in self.lora_A.keys():
+                        continue
+                    x = x.to(self.lora_A[active_adapter].weight.dtype)
+                    lora_A = self.lora_A[active_adapter]
+                    lora_B = self.lora_B[active_adapter]
+                    dropout = self.lora_dropout[active_adapter]
+                    scaling = self.scaling[active_adapter]
+                    result += lora_B(lora_A(dropout(self.qa_pool(x)))).to(expected_dtype) * scaling
             else:
-                output = (
-                    self.lora_B[self.active_adapter](
-                        self.lora_A[self.active_adapter](
-                            self.lora_dropout[self.active_adapter](self.qa_pool(x)))
-                    )
-                    * self.scaling[self.active_adapter]
-                )
-            result += output
+                for active_adapter in self.active_adapters:
+                    if active_adapter not in self.lora_A.keys():
+                        continue
+                    lora_A = self.lora_A[active_adapter]
+                    lora_B = self.lora_B[active_adapter]
+                    dropout = self.lora_dropout[active_adapter]
+                    scaling = self.scaling[active_adapter]
+                    result += lora_B(lora_A(dropout(self.qa_pool(x)))) * scaling
         return result


-class LoraBF16Linear(BF16Linear, LoraLayer):
+class LoraBF16Linear(Module, LoraLayer):
     # Lora implemented in a dense layer
     def __init__(
         self,
+        base_layer,
         adapter_name,
-        in_features,
-        out_features,
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
+        is_target_conv_1d_layer: bool = False,
+        init_lora_weights: Union[bool, str]=True,
+        use_rslora: bool = False,
+        use_dora: bool = False,
         **kwargs,
     ):
-        BF16Linear.__init__(
-            self,
-            in_features,
-            out_features,
-            bias=kwargs.get("bias", True),
-            compute_dtype=torch.bfloat16,
+        super().__init__()
+        LoraLayer.__init__(self, base_layer, **kwargs)
+
+        self.fan_in_fan_out = fan_in_fan_out
+        self._active_adapter = adapter_name
+        self.update_layer(
+            adapter_name,
+            r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            init_lora_weights=init_lora_weights,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
         )
-
-        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
-
-        # Freezing the pre-trained weight matrix
-        self.weight.requires_grad = False
-
-        init_lora_weights = kwargs.pop("init_lora_weights", True)
-        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self.is_target_conv_1d_layer = is_target_conv_1d_layer

     def forward(self, x: torch.Tensor):
         autocast_dtype = get_autocast_dtype(x)
@@ -170,31 +183,31 @@ class LoraBF16Linear(BF16Linear, LoraLayer):
             x = x.to(torch.bfloat16)
         elif autocast_dtype is not None:
             x = x.to(autocast_dtype)
-        result = super().forward(x)
+        result = self.base_layer.forward(x)

-        if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
+        if self.disable_adapters or self.merged:
             return result
-        elif self.r[self.active_adapter] > 0:
-            result = result.clone()
+        else:
            if autocast_dtype is None and x.device.type == "cpu":
                 expected_dtype = result.dtype
-                x = x.to(self.lora_A[self.active_adapter].weight.dtype)
-                output = (
-                    self.lora_B[self.active_adapter](
-                        self.lora_A[self.active_adapter](
-                            self.lora_dropout[self.active_adapter](x))
-                    ).to(expected_dtype)
-                    * self.scaling[self.active_adapter]
-                )
+                for active_adapter in self.active_adapters:
+                    if active_adapter not in self.lora_A.keys():
+                        continue
+                    x = x.to(self.lora_A[active_adapter].weight.dtype)
+                    lora_A = self.lora_A[active_adapter]
+                    lora_B = self.lora_B[active_adapter]
+                    dropout = self.lora_dropout[active_adapter]
+                    scaling = self.scaling[active_adapter]
+                    result += lora_B(lora_A(dropout(x))).to(expected_dtype) * scaling
             else:
-                output = (
-                    self.lora_B[self.active_adapter](
-                        self.lora_A[self.active_adapter](
-                            self.lora_dropout[self.active_adapter](x))
-                    )
-                    * self.scaling[self.active_adapter]
-                )
-            result += output
+                for active_adapter in self.active_adapters:
+                    if active_adapter not in self.lora_A.keys():
+                        continue
+                    lora_A = self.lora_A[active_adapter]
+                    lora_B = self.lora_B[active_adapter]
+                    dropout = self.lora_dropout[active_adapter]
+                    scaling = self.scaling[active_adapter]
+                    result += lora_B(lora_A(dropout(x))) * scaling
         return result


@@ -205,9 +218,8 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
     bias = low_bit_kwargs.pop("bias", False)

     if hasattr(lora_config, "training_mode") and lora_config.training_mode == "lora":
-        new_module = LoraBF16Linear(adapter_name,
-                                    target.in_features,
-                                    target.out_features,
+        new_module = LoraBF16Linear(target,
+                                    adapter_name,
                                     bias=bias,
                                     **low_bit_kwargs)
     else:
@@ -221,9 +233,8 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
                 "qa_lora": qa_lora
             }
         )
-        new_module = LoraLowBitLinear(adapter_name,
-                                      target.in_features,
-                                      target.out_features,
+        new_module = LoraLowBitLinear(target,
+                                      adapter_name,
                                       bias=bias,
                                       **low_bit_kwargs)
     else:
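Note (not part of the patch): the compatibility pivot above is the version gate added in
llm_patching.py. Newer peft releases deprecate and later drop the legacy
prepare_model_for_int8_training helper, so the patch only re-points it when an old peft
(0.5.0 or earlier) is installed; on peft 0.10.0 only prepare_model_for_kbit_training is
replaced. A minimal sketch of the same check, assuming only that peft and packaging are
importable (the flag name below is illustrative, not from the patch):

    from packaging import version
    import peft

    # Mirror the gate in llm_patch(): only very old peft still ships the legacy
    # int8 helper, so only there would it need to be re-pointed as well.
    patch_legacy_int8_helper = version.parse(peft.__version__) <= version.parse("0.5.0")
    print(peft.__version__, patch_legacy_int8_helper)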