Fix error when merging adapter (#11145)

binbin Deng 2024-05-27 19:41:42 +08:00 committed by GitHub
parent daf7b1cd56
commit c9168b85b7


@@ -47,6 +47,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/config.py
# Copyright [yyyy] [name of copyright owner]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def patch_prepare_ipex(self, *args):
@@ -58,6 +74,7 @@ from transformers.utils import (
    is_sagemaker_mp_enabled,
    is_accelerate_available,
    is_torch_xpu_available,
    is_peft_available,
    is_sagemaker_dp_enabled,
    is_torch_tpu_available,
    is_torch_npu_available)
@@ -69,6 +86,8 @@ import torch.distributed as dist
import os
import warnings
from datetime import timedelta
from huggingface_hub import hf_hub_download
from ipex_llm.utils.common import invalidInputError
if is_accelerate_available():
    from accelerate.state import AcceleratorState, PartialState
@@ -196,3 +215,50 @@ Accelerator._prepare_ipex = patch_prepare_ipex
# patch transformers for xpu DDP training
from transformers import TrainingArguments
TrainingArguments._setup_devices = _setup_devices
CONFIG_NAME = "adapter_config.json"


@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
    # Avoid circular dependency .. TODO: fix this with a larger refactor
    from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
    path = (
        os.path.join(pretrained_model_name_or_path, subfolder)
        if subfolder is not None
        else pretrained_model_name_or_path
    )
    hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
    if os.path.isfile(os.path.join(path, CONFIG_NAME)):
        config_file = os.path.join(path, CONFIG_NAME)
    else:
        try:
            config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME,
                                          subfolder=subfolder, **hf_hub_download_kwargs)
        except Exception:
            invalidInputError(False,
                              f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
    loaded_attributes = cls.from_json_file(config_file)
    if "peft_type" in loaded_attributes:
        peft_type = loaded_attributes["peft_type"]
        config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
    else:
        config_cls = cls
    config = config_cls(**class_kwargs)
    for key, value in loaded_attributes.items():
        if hasattr(config, key):
            setattr(config, key, value)
    return config


# patch peft for merging adapter into the original model
if is_peft_available():
    from peft.config import PeftConfigMixin
    PeftConfigMixin.from_pretrained = from_pretrained
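
For context, a minimal usage sketch of the adapter-merge flow that goes through the patched PeftConfigMixin.from_pretrained above; the base model id "my-base-model" and the adapter directory "./lora-adapter" are placeholders and not part of this commit:

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model, then attach the saved LoRA adapter; peft resolves
# adapter_config.json through PeftConfigMixin.from_pretrained (patched above).
base = AutoModelForCausalLM.from_pretrained("my-base-model", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, "./lora-adapter")
# Fold the adapter weights into the base weights and save a standalone model.
merged = model.merge_and_unload()
merged.save_pretrained("./merged-model")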