Fix error during merging adapter (#11145)
This commit is contained in:
parent
daf7b1cd56
commit
c9168b85b7
1 changed files with 66 additions and 0 deletions
|
|
@ -47,6 +47,22 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Some parts of this file is adapted from
|
||||||
|
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/config.py
|
||||||
|
# Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_ipex(self, *args):
|
def patch_prepare_ipex(self, *args):
|
||||||
|
|
@ -58,6 +74,7 @@ from transformers.utils import (
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
|
is_peft_available,
|
||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torch_npu_available)
|
is_torch_npu_available)
|
||||||
|
|
@ -69,6 +86,8 @@ import torch.distributed as dist
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from ipex_llm.utils.common import invalidInputError
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate.state import AcceleratorState, PartialState
|
from accelerate.state import AcceleratorState, PartialState
|
||||||
|
|
@ -196,3 +215,50 @@ Accelerator._prepare_ipex = patch_prepare_ipex
|
||||||
# patch transformer for xpu DDP traing
|
# patch transformer for xpu DDP traing
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
TrainingArguments._setup_devices = _setup_devices
|
TrainingArguments._setup_devices = _setup_devices
|
||||||
|
|
||||||
|
CONFIG_NAME = "adapter_config.json"
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
|
||||||
|
# Avoid circular dependency .. TODO: fix this with a larger refactor
|
||||||
|
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
|
||||||
|
|
||||||
|
path = (
|
||||||
|
os.path.join(pretrained_model_name_or_path, subfolder)
|
||||||
|
if subfolder is not None
|
||||||
|
else pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
|
||||||
|
hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
|
||||||
|
|
||||||
|
if os.path.isfile(os.path.join(path, CONFIG_NAME)):
|
||||||
|
config_file = os.path.join(path, CONFIG_NAME)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME,
|
||||||
|
subfolder=subfolder, **hf_hub_download_kwargs)
|
||||||
|
except Exception:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
|
||||||
|
|
||||||
|
loaded_attributes = cls.from_json_file(config_file)
|
||||||
|
|
||||||
|
if "peft_type" in loaded_attributes:
|
||||||
|
peft_type = loaded_attributes["peft_type"]
|
||||||
|
config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
|
||||||
|
else:
|
||||||
|
config_cls = cls
|
||||||
|
|
||||||
|
config = config_cls(**class_kwargs)
|
||||||
|
|
||||||
|
for key, value in loaded_attributes.items():
|
||||||
|
if hasattr(config, key):
|
||||||
|
setattr(config, key, value)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
# patch peft for merging adapter into the original model
|
||||||
|
if is_peft_available():
|
||||||
|
from peft.config import PeftConfigMixin
|
||||||
|
PeftConfigMixin.from_pretrained = from_pretrained
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue