Upgrade Peft version to 0.10.0 for LLM finetune (#10886)
* Upgrade Peft version to 0.10.0 * Upgrade Peft version in ARC unit test and HF-Peft example.
This commit is contained in:
parent
0efe26c3b6
commit
d7ca5d935b
5 changed files with 105 additions and 90 deletions
2
.github/workflows/llm_unit_tests.yml
vendored
2
.github/workflows/llm_unit_tests.yml
vendored
|
|
@ -370,7 +370,7 @@ jobs:
|
|||
shell: bash
|
||||
run: |
|
||||
python -m pip uninstall datasets -y
|
||||
python -m pip install transformers==4.34.0 datasets peft==0.5.0 accelerate==0.23.0
|
||||
python -m pip install transformers==4.36.0 datasets peft==0.10.0 accelerate==0.23.0
|
||||
python -m pip install bitsandbytes scipy
|
||||
# Specific oneapi position on arc ut test machines
|
||||
if [[ '${{ matrix.pytorch-version }}' == '2.1' ]]; then
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ conda create -n llm python=3.11
|
|||
conda activate llm
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
pip install transformers==4.34.0 datasets
|
||||
pip install fire peft==0.5.0
|
||||
pip install transformers==4.36.0 datasets
|
||||
pip install fire peft==0.10.0
|
||||
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ # necessary to run distributed finetuning
|
||||
pip install accelerate==0.23.0
|
||||
pip install bitsandbytes scipy
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ conda create -n llm python=3.11
|
|||
conda activate llm
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
pip install transformers==4.34.0 datasets
|
||||
pip install fire peft==0.5.0
|
||||
pip install transformers==4.36.0 datasets
|
||||
pip install fire peft==0.10.0
|
||||
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ # necessary to run distributed finetuning
|
||||
pip install accelerate==0.23.0
|
||||
pip install bitsandbytes scipy
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@
|
|||
import transformers
|
||||
import importlib
|
||||
import sys
|
||||
from packaging import version
|
||||
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from enum import Enum
|
||||
|
||||
|
|
@ -57,12 +59,14 @@ def llm_patch(train=False):
|
|||
import peft
|
||||
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
|
||||
LoraConfig, TrainingArguments
|
||||
peft_version = peft.__version__
|
||||
replace_attr(transformers, "TrainingArguments", TrainingArguments)
|
||||
get_peft_model_original = getattr(peft, "get_peft_model")
|
||||
replace_attr(peft, "get_peft_model", get_peft_model)
|
||||
setattr(peft, "get_peft_model_original", get_peft_model_original)
|
||||
replace_attr(peft, "prepare_model_for_kbit_training", prepare_model_for_kbit_training)
|
||||
replace_attr(peft, "prepare_model_for_int8_training", prepare_model_for_kbit_training)
|
||||
if version.parse(peft_version) <= version.parse("0.5.0"):
|
||||
replace_attr(peft, "prepare_model_for_int8_training", prepare_model_for_kbit_training)
|
||||
replace_attr(peft, "LoraConfig", LoraConfig)
|
||||
bigdl_patched = 'Train'
|
||||
|
||||
|
|
|
|||
|
|
@ -50,9 +50,10 @@
|
|||
|
||||
import torch
|
||||
import logging
|
||||
from torch.nn import Linear, Embedding
|
||||
from torch.nn import Linear, Embedding, Module
|
||||
from ipex_llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from typing import Any, Optional, Union
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.utils import get_autocast_dtype
|
||||
from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
||||
|
|
@ -62,38 +63,46 @@ from ipex_llm.transformers import training_patch
|
|||
LOG = logging.getLogger("ipex_llm.qlora")
|
||||
|
||||
|
||||
class LoraLowBitLinear(LowBitLinear, LoraLayer):
|
||||
class LoraLowBitLinear(Module, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name,
|
||||
in_features,
|
||||
out_features,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
qa_lora: bool = True,
|
||||
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
fan_in_fan_out: bool = False,
|
||||
is_target_conv_1d_layer: bool = False,
|
||||
init_lora_weights: Union[bool, str]=True,
|
||||
use_rslora: bool = False,
|
||||
use_dora: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
LowBitLinear.__init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
qtype=kwargs.get("qtype"),
|
||||
bias=kwargs.get("bias", True),
|
||||
conver_to_half=False,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
qk_size = get_qk_size(kwargs.get("qtype"))
|
||||
lora_in_features = in_features // qk_size if qa_lora else in_features
|
||||
LoraLayer.__init__(self, in_features=lora_in_features, out_features=out_features)
|
||||
if qa_lora:
|
||||
# qa_lora need to change the in_features of the base_layer
|
||||
in_features = base_layer.in_features
|
||||
base_layer.in_features = in_features // qk_size
|
||||
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
LoraLayer.__init__(self, base_layer, **kwargs)
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
init_lora_weights=init_lora_weights,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.active_adapter = adapter_name
|
||||
if qa_lora:
|
||||
self.qa_pool = torch.nn.AvgPool1d(qk_size)
|
||||
else:
|
||||
|
|
@ -106,62 +115,66 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
|
|||
x = x.to(torch.bfloat16)
|
||||
elif autocast_dtype is not None:
|
||||
x = x.to(autocast_dtype)
|
||||
result = super().forward(x)
|
||||
result = self.base_layer.forward(x)
|
||||
|
||||
if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
|
||||
if self.disable_adapters or self.merged:
|
||||
return result
|
||||
elif self.r[self.active_adapter] > 0:
|
||||
result = result.clone()
|
||||
else:
|
||||
if autocast_dtype is None and x.device.type == "cpu":
|
||||
expected_dtype = result.dtype
|
||||
x = x.to(self.lora_A[self.active_adapter].weight.dtype)
|
||||
output = (
|
||||
self.lora_B[self.active_adapter](
|
||||
self.lora_A[self.active_adapter](
|
||||
self.lora_dropout[self.active_adapter](self.qa_pool(x)))
|
||||
).to(expected_dtype)
|
||||
* self.scaling[self.active_adapter]
|
||||
)
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
x = x.to(self.lora_A[active_adapter].weight.dtype)
|
||||
lora_A = self.lora_A[active_adapter]
|
||||
lora_B = self.lora_B[active_adapter]
|
||||
dropout = self.lora_dropout[active_adapter]
|
||||
scaling = self.scaling[active_adapter]
|
||||
result += lora_B(lora_A(dropout(self.qa_pool(x)))).to(expected_dtype) * scaling
|
||||
else:
|
||||
output = (
|
||||
self.lora_B[self.active_adapter](
|
||||
self.lora_A[self.active_adapter](
|
||||
self.lora_dropout[self.active_adapter](self.qa_pool(x)))
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
)
|
||||
result += output
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
lora_A = self.lora_A[active_adapter]
|
||||
lora_B = self.lora_B[active_adapter]
|
||||
dropout = self.lora_dropout[active_adapter]
|
||||
scaling = self.scaling[active_adapter]
|
||||
result += lora_B(lora_A(dropout(self.qa_pool(x)))) * scaling
|
||||
return result
|
||||
|
||||
|
||||
class LoraBF16Linear(BF16Linear, LoraLayer):
|
||||
class LoraBF16Linear(Module, LoraLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name,
|
||||
in_features,
|
||||
out_features,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
fan_in_fan_out: bool = False,
|
||||
is_target_conv_1d_layer: bool = False,
|
||||
init_lora_weights: Union[bool, str]=True,
|
||||
use_rslora: bool = False,
|
||||
use_dora: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
BF16Linear.__init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
bias=kwargs.get("bias", True),
|
||||
compute_dtype=torch.bfloat16,
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, base_layer, **kwargs)
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
init_lora_weights=init_lora_weights,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
|
||||
LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
|
||||
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
||||
self.active_adapter = adapter_name
|
||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
autocast_dtype = get_autocast_dtype(x)
|
||||
|
|
@ -170,31 +183,31 @@ class LoraBF16Linear(BF16Linear, LoraLayer):
|
|||
x = x.to(torch.bfloat16)
|
||||
elif autocast_dtype is not None:
|
||||
x = x.to(autocast_dtype)
|
||||
result = super().forward(x)
|
||||
result = self.base_layer.forward(x)
|
||||
|
||||
if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
|
||||
if self.disable_adapters or self.merged:
|
||||
return result
|
||||
elif self.r[self.active_adapter] > 0:
|
||||
result = result.clone()
|
||||
else:
|
||||
if autocast_dtype is None and x.device.type == "cpu":
|
||||
expected_dtype = result.dtype
|
||||
x = x.to(self.lora_A[self.active_adapter].weight.dtype)
|
||||
output = (
|
||||
self.lora_B[self.active_adapter](
|
||||
self.lora_A[self.active_adapter](
|
||||
self.lora_dropout[self.active_adapter](x))
|
||||
).to(expected_dtype)
|
||||
* self.scaling[self.active_adapter]
|
||||
)
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
x = x.to(self.lora_A[active_adapter].weight.dtype)
|
||||
lora_A = self.lora_A[active_adapter]
|
||||
lora_B = self.lora_B[active_adapter]
|
||||
dropout = self.lora_dropout[active_adapter]
|
||||
scaling = self.scaling[active_adapter]
|
||||
result += lora_B(lora_A(dropout(x))).to(expected_dtype) * scaling
|
||||
else:
|
||||
output = (
|
||||
self.lora_B[self.active_adapter](
|
||||
self.lora_A[self.active_adapter](
|
||||
self.lora_dropout[self.active_adapter](x))
|
||||
)
|
||||
* self.scaling[self.active_adapter]
|
||||
)
|
||||
result += output
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
lora_A = self.lora_A[active_adapter]
|
||||
lora_B = self.lora_B[active_adapter]
|
||||
dropout = self.lora_dropout[active_adapter]
|
||||
scaling = self.scaling[active_adapter]
|
||||
result += lora_B(lora_A(dropout(x))) * scaling
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -205,9 +218,8 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
|
|||
bias = low_bit_kwargs.pop("bias", False)
|
||||
|
||||
if hasattr(lora_config, "training_mode") and lora_config.training_mode == "lora":
|
||||
new_module = LoraBF16Linear(adapter_name,
|
||||
target.in_features,
|
||||
target.out_features,
|
||||
new_module = LoraBF16Linear(target,
|
||||
adapter_name,
|
||||
bias=bias,
|
||||
**low_bit_kwargs)
|
||||
else:
|
||||
|
|
@ -221,9 +233,8 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
|
|||
"qa_lora": qa_lora
|
||||
}
|
||||
)
|
||||
new_module = LoraLowBitLinear(adapter_name,
|
||||
target.in_features,
|
||||
target.out_features,
|
||||
new_module = LoraLowBitLinear(target,
|
||||
adapter_name,
|
||||
bias=bias,
|
||||
**low_bit_kwargs)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue