Fix error during merging adapter (#11145)
This commit is contained in:
		
							parent
							
								
									daf7b1cd56
								
							
						
					
					
						commit
						c9168b85b7
					
				
					 1 changed files with 66 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -47,6 +47,22 @@
 | 
			
		|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/config.py
 | 
			
		||||
# Copyright [yyyy] [name of copyright owner]
 | 
			
		||||
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_prepare_ipex(self, *args):
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +74,7 @@ from transformers.utils import (
 | 
			
		|||
    is_sagemaker_mp_enabled,
 | 
			
		||||
    is_accelerate_available,
 | 
			
		||||
    is_torch_xpu_available,
 | 
			
		||||
    is_peft_available,
 | 
			
		||||
    is_sagemaker_dp_enabled,
 | 
			
		||||
    is_torch_tpu_available,
 | 
			
		||||
    is_torch_npu_available)
 | 
			
		||||
| 
						 | 
				
			
			@ -69,6 +86,8 @@ import torch.distributed as dist
 | 
			
		|||
import os
 | 
			
		||||
import warnings
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
if is_accelerate_available():
 | 
			
		||||
    from accelerate.state import AcceleratorState, PartialState
 | 
			
		||||
| 
						 | 
				
			
			@ -196,3 +215,50 @@ Accelerator._prepare_ipex = patch_prepare_ipex
 | 
			
		|||
# patch transformer for xpu DDP traing
 | 
			
		||||
from transformers import TrainingArguments
 | 
			
		||||
TrainingArguments._setup_devices = _setup_devices
 | 
			
		||||
 | 
			
		||||
CONFIG_NAME = "adapter_config.json"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@classmethod
 | 
			
		||||
def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs):
 | 
			
		||||
    # Avoid circular dependency .. TODO: fix this with a larger refactor
 | 
			
		||||
    from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
 | 
			
		||||
 | 
			
		||||
    path = (
 | 
			
		||||
        os.path.join(pretrained_model_name_or_path, subfolder)
 | 
			
		||||
        if subfolder is not None
 | 
			
		||||
        else pretrained_model_name_or_path
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
 | 
			
		||||
 | 
			
		||||
    if os.path.isfile(os.path.join(path, CONFIG_NAME)):
 | 
			
		||||
        config_file = os.path.join(path, CONFIG_NAME)
 | 
			
		||||
    else:
 | 
			
		||||
        try:
 | 
			
		||||
            config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME,
 | 
			
		||||
                                          subfolder=subfolder, **hf_hub_download_kwargs)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            invalidInputError(False,
 | 
			
		||||
                              f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
 | 
			
		||||
 | 
			
		||||
    loaded_attributes = cls.from_json_file(config_file)
 | 
			
		||||
 | 
			
		||||
    if "peft_type" in loaded_attributes:
 | 
			
		||||
        peft_type = loaded_attributes["peft_type"]
 | 
			
		||||
        config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
 | 
			
		||||
    else:
 | 
			
		||||
        config_cls = cls
 | 
			
		||||
 | 
			
		||||
    config = config_cls(**class_kwargs)
 | 
			
		||||
 | 
			
		||||
    for key, value in loaded_attributes.items():
 | 
			
		||||
        if hasattr(config, key):
 | 
			
		||||
            setattr(config, key, value)
 | 
			
		||||
 | 
			
		||||
    return config
 | 
			
		||||
 | 
			
		||||
# patch peft for merging adapter into the original model
 | 
			
		||||
if is_peft_available():
 | 
			
		||||
    from peft.config import PeftConfigMixin
 | 
			
		||||
    PeftConfigMixin.from_pretrained = from_pretrained
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue