Fix error during merging adapter (#11145)
This commit is contained in:
parent
daf7b1cd56
commit
c9168b85b7
1 changed files with 66 additions and 0 deletions
|
|
@ -47,6 +47,22 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/config.py
|
||||
# Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
def patch_prepare_ipex(self, *args):
|
||||
|
|
@ -58,6 +74,7 @@ from transformers.utils import (
|
|||
is_sagemaker_mp_enabled,
|
||||
is_accelerate_available,
|
||||
is_torch_xpu_available,
|
||||
is_peft_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_torch_tpu_available,
|
||||
is_torch_npu_available)
|
||||
|
|
@ -69,6 +86,8 @@ import torch.distributed as dist
|
|||
import os
|
||||
import warnings
|
||||
from datetime import timedelta
|
||||
from huggingface_hub import hf_hub_download
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.state import AcceleratorState, PartialState
|
||||
|
|
@ -196,3 +215,50 @@ Accelerator._prepare_ipex = patch_prepare_ipex
|
|||
# patch transformer for xpu DDP traing
|
||||
from transformers import TrainingArguments
|
||||
TrainingArguments._setup_devices = _setup_devices
|
||||
|
||||
CONFIG_NAME = "adapter_config.json"
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
|
||||
# Avoid circular dependency .. TODO: fix this with a larger refactor
|
||||
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
|
||||
|
||||
path = (
|
||||
os.path.join(pretrained_model_name_or_path, subfolder)
|
||||
if subfolder is not None
|
||||
else pretrained_model_name_or_path
|
||||
)
|
||||
|
||||
hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
|
||||
|
||||
if os.path.isfile(os.path.join(path, CONFIG_NAME)):
|
||||
config_file = os.path.join(path, CONFIG_NAME)
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME,
|
||||
subfolder=subfolder, **hf_hub_download_kwargs)
|
||||
except Exception:
|
||||
invalidInputError(False,
|
||||
f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
|
||||
|
||||
loaded_attributes = cls.from_json_file(config_file)
|
||||
|
||||
if "peft_type" in loaded_attributes:
|
||||
peft_type = loaded_attributes["peft_type"]
|
||||
config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
|
||||
else:
|
||||
config_cls = cls
|
||||
|
||||
config = config_cls(**class_kwargs)
|
||||
|
||||
for key, value in loaded_attributes.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
return config
|
||||
|
||||
# patch peft for merging adapter into the original model
|
||||
if is_peft_available():
|
||||
from peft.config import PeftConfigMixin
|
||||
PeftConfigMixin.from_pretrained = from_pretrained
|
||||
|
|
|
|||
Loading…
Reference in a new issue