ipex-llm/python/llm/src/ipex_llm/llm_patching.py
Wang, Jian4 9df70d95eb
Refactor bigdl.llm to ipex_llm (#24)
* Rename bigdl/llm to ipex_llm

* rm python/llm/src/bigdl

* from bigdl.llm to from ipex_llm
2024-03-22 15:41:21 +08:00

79 lines
2.8 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import transformers
import importlib
import sys
from ipex_llm.utils.common import invalidInputError
from enum import Enum
# Patch state of this module. llm_patch() sets it to 'Train' after applying the
# training patch and llm_unpatch() resets it to None; both use it as their
# "already done" guard.
bigdl_patched = None # None or 'Train' or 'Inference'
# Undo log filled by replace_attr(): one (obj, attribute name, original value)
# tuple per replaced attribute, consumed by llm_unpatch().
attrs = []
def replace_attr(obj, name: str, value):
    '''
    Overwrite `obj.<name>` with `value` and remember the previous value.

    The previous value is looked up with `getattr`, so the attribute has to exist
    already. A `(obj, name, previous value)` record is appended to the module-level
    `attrs` list once the new value is in place.

    :param obj: Object (e.g. a module) whose attribute is replaced.
    :param name: Name of the attribute to replace.
    :param value: New value to install under `name`.
    '''
    # Capture the record before the overwrite, publish it only after it succeeded.
    undo_record = (obj, name, getattr(obj, name))
    setattr(obj, name, value)
    attrs.append(undo_record)
def llm_patch(train=False):
    '''
    llm_patch is used to make users' LLM application benefit from BigDL-LLM optimization
    with only one-line code patch.

    Calling it while a patch is already active is a no-op. Only the training patch is
    implemented; with `train=False` nothing is replaced and the patch state is unchanged.

    When `train` is true, `peft` must not have been imported yet. That is verified
    before anything is modified, so a failed check leaves `transformers` untouched.

    :param train: Whether to apply bigdl-llm patch for training code, default to be `False`.
    '''
    global bigdl_patched
    if bigdl_patched:
        return

    # Initial version of patch for llm finetuning, inference support TBD
    if not train:
        return

    # Validate first: if this check ran after attributes were replaced, a failure would
    # leave them patched while bigdl_patched is still None, and llm_unpatch() (which
    # returns early in that state) could not undo them.
    peft_modules = ('peft', 'peft.utils', 'peft.tuners', 'peft.mapping')
    import_peft_check = any(module in sys.modules for module in peft_modules)
    # invalidInputError is expected to raise when its first argument is false.
    invalidInputError(not import_peft_check,
                      'llm_patch() should be called at the beginning of your code.')

    from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
    replace_attr(transformers, "AutoModelForCausalLM", AutoModelForCausalLM)
    replace_attr(transformers, "LlamaForCausalLM", AutoModelForCausalLM)
    replace_attr(transformers, "AutoModel", AutoModel)

    # peft is imported only after transformers has been patched.
    # NOTE(review): presumably so peft binds to the replaced transformers classes on
    # first import - keep this ordering.
    import peft
    from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
        LoraConfig, TrainingArguments
    replace_attr(transformers, "TrainingArguments", TrainingArguments)

    # Keep the stock implementation reachable under a separate name. This attribute is
    # set directly, so it is not recorded in the undo log.
    get_peft_model_original = getattr(peft, "get_peft_model")
    replace_attr(peft, "get_peft_model", get_peft_model)
    setattr(peft, "get_peft_model_original", get_peft_model_original)

    replace_attr(peft, "prepare_model_for_kbit_training", prepare_model_for_kbit_training)
    # replace_attr reads the current value with getattr, which raises AttributeError for
    # a missing name and would abort the patch half-way.
    # NOTE(review): the int8 alias appears to be absent from newer peft releases - confirm.
    if hasattr(peft, "prepare_model_for_int8_training"):
        replace_attr(peft, "prepare_model_for_int8_training", prepare_model_for_kbit_training)
    replace_attr(peft, "LoraConfig", LoraConfig)

    bigdl_patched = 'Train'
def llm_unpatch():
    '''
    llm_unpatch is a reverse function to llm_patch.

    Every attribute recorded in the module-level `attrs` list is set back to its saved
    original value, the list is emptied, and the patch state is reset to `None`.
    Does nothing when no patch is active.

    Attributes that llm_patch sets directly rather than through replace_attr
    (`peft.get_peft_model_original`) are not recorded and are therefore left in place.
    '''
    global bigdl_patched
    if bigdl_patched is None:
        return

    # Undo newest-first: if an attribute was replaced more than once, the earliest
    # record holds the true original and must be the last one written back.
    for obj, name, torch_attr in reversed(attrs):
        setattr(obj, name, torch_attr)
    # Drop the consumed records so repeated patch/unpatch cycles do not accumulate
    # stale entries (and the object references they hold).
    attrs.clear()
    bigdl_patched = None