# ipex-llm/python/llm/src/ipex_llm/transformers/qlora.py

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/tuners/lora.py
#
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/tuners/lora.py
#
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import logging
from torch.nn import Linear, Embedding, Module
from ipex_llm.transformers.low_bit_linear import LowBitLinear, BF16Linear, get_qk_size
from peft.tuners.lora import LoraLayer
from typing import Any, Optional, Union
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.utils import get_autocast_dtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype
import functools
# Imported only for its side effects; it is not referenced directly in this file.
from ipex_llm.transformers import training_patch

LOG = logging.getLogger("ipex_llm.qlora")


class LoraLowBitLinear(Module, LoraLayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        base_layer,
        adapter_name,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        qa_lora: bool = True,
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        fan_in_fan_out: bool = False,
        is_target_conv_1d_layer: bool = False,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ):
        super().__init__()
        qk_size = get_qk_size(kwargs.get("qtype"))
        if qa_lora:
            # QA-LoRA needs to change the in_features of the base_layer
            in_features = base_layer.in_features
            base_layer.in_features = in_features // qk_size

        LoraLayer.__init__(self, base_layer, **kwargs)

        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

        if qa_lora:
            self.qa_pool = torch.nn.AvgPool1d(qk_size)
        else:
            self.qa_pool = torch.nn.Identity()

    def forward(self, x: torch.Tensor):
        autocast_dtype = get_autocast_dtype(x)
        if x.device.type == "xpu":
            # force to use bf16 on gpu
            x = x.to(torch.bfloat16)
        elif autocast_dtype is not None:
            x = x.to(autocast_dtype)
        result = self.base_layer.forward(x)

        if self.disable_adapters or self.merged:
            return result
        else:
            if autocast_dtype is None and x.device.type == "cpu":
                expected_dtype = result.dtype
                for active_adapter in self.active_adapters:
                    if active_adapter not in self.lora_A.keys():
                        continue
                    x = x.to(self.lora_A[active_adapter].weight.dtype)
                    lora_A = self.lora_A[active_adapter]
                    lora_B = self.lora_B[active_adapter]
                    dropout = self.lora_dropout[active_adapter]
                    scaling = self.scaling[active_adapter]
                    result += lora_B(lora_A(dropout(self.qa_pool(x)))).to(expected_dtype) * scaling
            else:
                for active_adapter in self.active_adapters:
                    if active_adapter not in self.lora_A.keys():
                        continue
                    lora_A = self.lora_A[active_adapter]
                    lora_B = self.lora_B[active_adapter]
                    dropout = self.lora_dropout[active_adapter]
                    scaling = self.scaling[active_adapter]
                    result += lora_B(lora_A(dropout(self.qa_pool(x)))) * scaling

        return result
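
# Illustrative sketch (comments only, not executed): with QA-LoRA enabled, the AvgPool1d
# above shrinks the adapter input from `in_features` to `in_features // qk_size`, which is
# why the base layer's in_features is divided by qk_size before LoraLayer.__init__.
# Assuming a hypothetical 4096-wide layer and a quantization group size of 64:
#
#     pool = torch.nn.AvgPool1d(64)        # qk_size assumed to be 64 here
#     x = torch.randn(1, 8, 4096)          # (batch, seq_len, in_features)
#     print(pool(x).shape)                 # torch.Size([1, 8, 64]) -> input size of lora_A
#
# The real group size comes from get_qk_size(kwargs.get("qtype")).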


class LoraBF16Linear(Module, LoraLayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        base_layer,
        adapter_name,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        fan_in_fan_out: bool = False,
        is_target_conv_1d_layer: bool = False,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ):
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)

        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

    def forward(self, x: torch.Tensor):
        autocast_dtype = get_autocast_dtype(x)
        if x.device.type == "xpu":
            # force to use bf16 on gpu
            x = x.to(torch.bfloat16)
        elif autocast_dtype is not None:
            x = x.to(autocast_dtype)
        result = self.base_layer.forward(x)

        if self.disable_adapters or self.merged:
            return result
        else:
            if autocast_dtype is None and x.device.type == "cpu":
                expected_dtype = result.dtype
                for active_adapter in self.active_adapters:
                    if active_adapter not in self.lora_A.keys():
                        continue
                    x = x.to(self.lora_A[active_adapter].weight.dtype)
                    lora_A = self.lora_A[active_adapter]
                    lora_B = self.lora_B[active_adapter]
                    dropout = self.lora_dropout[active_adapter]
                    scaling = self.scaling[active_adapter]
                    result += lora_B(lora_A(dropout(x))).to(expected_dtype) * scaling
            else:
                for active_adapter in self.active_adapters:
                    if active_adapter not in self.lora_A.keys():
                        continue
                    lora_A = self.lora_A[active_adapter]
                    lora_B = self.lora_B[active_adapter]
                    dropout = self.lora_dropout[active_adapter]
                    scaling = self.scaling[active_adapter]
                    result += lora_B(lora_A(dropout(x))) * scaling

        return result


def _create_new_module(create_new_module_func, lora_config, adapter_name, target, **kwargs):
    if isinstance(target, LowBitLinear) or isinstance(target, BF16Linear):
        low_bit_kwargs = kwargs.copy()
        bias = low_bit_kwargs.pop("bias", False)

        if hasattr(lora_config, "training_mode") and lora_config.training_mode == "lora":
            new_module = LoraBF16Linear(target,
                                        adapter_name,
                                        bias=bias,
                                        **low_bit_kwargs)
        else:
            if hasattr(lora_config, "training_mode"):
                qa_lora = lora_config.training_mode == "qalora"
            else:
                qa_lora = False
            low_bit_kwargs.update(
                {
                    "qtype": target.qtype,
                    "qa_lora": qa_lora,
                }
            )
            new_module = LoraLowBitLinear(target,
                                          adapter_name,
                                          bias=bias,
                                          **low_bit_kwargs)
    else:
        new_module = create_new_module_func(lora_config, adapter_name, target, **kwargs)
    return new_module


from peft.tuners.lora import LoraModel
from peft.tuners.lora import LoraConfig as LoraConfigBase
from transformers import TrainingArguments as TrainingArgumentsBase
from transformers.training_args import OptimizerNames
from dataclasses import dataclass, field


@dataclass
class LoraConfig(LoraConfigBase):
    training_mode: str = field(default="qlora", metadata={"help": "determine training mode"})

    def __init__(self, *args, **kwargs):
        self.training_mode = kwargs.pop("training_mode", "qlora")
        super().__init__(*args, **kwargs)
        from ipex_llm.llm_patching import bigdl_patched
        if bigdl_patched == 'Train':
            from .model import patched_training_mode
            self.training_mode = patched_training_mode
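
# A minimal configuration sketch. The target module names and hyperparameters below are
# placeholders for a Llama-style model, not values recommended by this file:
#
#     config = LoraConfig(
#         r=8,
#         lora_alpha=32,
#         target_modules=["q_proj", "k_proj", "v_proj"],   # hypothetical targets
#         lora_dropout=0.05,
#         bias="none",
#         task_type="CAUSAL_LM",
#         training_mode="qlora",   # or "lora" / "qalora"; see _create_new_module above
#     )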


supported_optim = ["adamw_hf", "adamw_torch", "adafactor", "sgd", "adagrad", "rmsprop"]


@dataclass
class TrainingArguments(TrainingArgumentsBase):
    def __init__(self, *args, **kwargs):
        # ipex-llm finetuning runs in bf16, so force bf16 and disable fp16
        kwargs["fp16"] = False
        kwargs["bf16"] = True
        # accept the optimizer either as a plain string or as an OptimizerNames member
        allowed_optim = supported_optim + [OptimizerNames(name) for name in supported_optim]
        optim = kwargs.get("optim", "adamw_torch")
        if optim not in allowed_optim:
            LOG.info(f"{optim} is not supported yet and "
                     f"the adamw_torch optimizer is used instead.")
            optim = "adamw_torch"
        kwargs["optim"] = optim
        super().__init__(*args, **kwargs)
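
# Usage sketch: this wrapper silently forces bf16 training and falls back to adamw_torch
# for unsupported optimizers, so a typical call only sets the usual HuggingFace fields
# (the values below are placeholders, not recommendations):
#
#     args = TrainingArguments(
#         output_dir="./lora-output",
#         per_device_train_batch_size=4,
#         learning_rate=2e-4,
#         optim="adamw_torch",
#     )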


def get_peft_model(*args, **kwargs):
    old_create_new_module = LoraModel._create_new_module
    LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
                                                                  old_create_new_module))

    try:
        from ipex_llm.llm_patching import bigdl_patched
        if bigdl_patched == 'Train':
            from peft import get_peft_model_original
        else:
            from peft import get_peft_model as get_peft_model_original
        model = get_peft_model_original(*args, **kwargs)
    finally:
        LoraModel._create_new_module = old_create_new_module

    if model.device.type == "xpu":
        cast_lora_weight(model, torch.bfloat16)
        _optimize_post(model)
        torch.xpu.synchronize()

    return model


def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    r"""
    This method wraps the entire protocol for preparing a model before running training.
    This includes:
        1- casting the layernorm to fp32
        2- making the output embedding layer require grads
        3- upcasting the lm head to fp32

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
    """
    is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

    if not is_gptq_quantized:
        # cast all non INT8 parameters to fp32
        # for param in model.parameters():
        #     if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
        #         param.data = param.data.to(torch.float32)
        # Instead of the above, cast only leaf modules other than Linear/Embedding,
        # which reduces memory; otherwise LoRA finetuning on Arc may OOM during this
        # conversion.
        for module in model.modules():
            if not list(module.children()):
                # leaf module
                if not isinstance(module, (Linear, Embedding)):
                    for param in module.parameters():
                        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()
    return model
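
# End-to-end sketch of how the helpers above are typically combined. The model id,
# low-bit format and hyperparameters are placeholders, not tested recipes; see the QLoRA
# finetuning examples in this repository for complete scripts:
#
#     from ipex_llm.transformers import AutoModelForCausalLM
#
#     base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
#                                                       load_in_low_bit="nf4",
#                                                       optimize_model=False)
#     base_model = prepare_model_for_kbit_training(base_model)
#     peft_model = get_peft_model(base_model,
#                                 LoraConfig(r=8, lora_alpha=32, training_mode="qlora"))
#
# Moving the base model to "xpu" before calling get_peft_model lets the wrapper cast the
# LoRA weights to bf16 and install the fused Llama attention forward.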


class PeftModel:
    @staticmethod
    def from_pretrained(*args,
                        **kwargs):
        old_create_new_module = LoraModel._create_new_module
        LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
                                                                      old_create_new_module))
        from peft import PeftModel
        try:
            model = PeftModel.from_pretrained(*args, **kwargs)
        finally:
            LoraModel._create_new_module = old_create_new_module

        return model
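
# Loading a saved adapter back goes through the same patched dispatch; for example (the
# paths below are placeholders):
#
#     peft_model = PeftModel.from_pretrained(base_model, "./lora-output/checkpoint-200")
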
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
PEFT_TYPE_TO_CONFIG_MAPPING["lora"] = LoraConfig


def cast_lora_weight(model, dtype=torch.bfloat16):
    for name, module in model.named_modules():
        if isinstance(module, LowBitLinear):
            module.compute_dtype = dtype
        if isinstance(module, LoraLayer):
            module = module.to(dtype)
        if isinstance(module, BF16Linear):
            module = module.to(dtype)
            module.compute_dtype = dtype
        if 'norm' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if module.weight.dtype == torch.float32:
                    module = module.to(dtype)


def _optimize_post(model):
    import transformers
    from packaging import version
    from ipex_llm.transformers.convert import convert_forward
    from ipex_llm.transformers.models.llama import llama_attention_fast_forward

    trans_version = transformers.__version__
    if version.parse(trans_version) >= version.parse("4.31.0"):
        LOG.info("Optimizing Llama finetuning...")
        convert_forward(
            model,
            transformers.models.llama.modeling_llama.LlamaAttention,
            llama_attention_fast_forward,
        )