From 70f5e7bf0d9eaf2c4c00a0198216cea3b1e3c6a3 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Fri, 8 Dec 2023 16:13:03 +0800
Subject: [PATCH] Support peft LoraConfig (#9636)

* support peft loraconfig

* use testcase to test

* fix style

* meet comments
---
 python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py | 4 ++--
 python/llm/src/bigdl/llm/transformers/qlora.py              | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
index c2044f26..88dea463 100644
--- a/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
+++ b/python/llm/example/GPU/QLoRA-FineTuning/qlora_finetuning.py
@@ -20,8 +20,8 @@ import os
 import transformers
 from transformers import LlamaTokenizer
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
-    LoraConfig
+from peft import LoraConfig
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index f507755f..0c55c840 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -136,7 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
         low_bit_kwargs.update(
             {
                 "qtype": target.qtype,
-                "qa_lora": lora_config.qa_lora,
+                "qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False,
             }
         )
         new_module = LoraLowBitLinear(adapter_name,
@@ -165,6 +165,7 @@ def get_peft_model(*args, **kwargs):
     old_create_new_module = LoraModel._create_new_module
     LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
                                                                   old_create_new_module))
+
     try:
         from peft import get_peft_model as get_peft_model_original
         model = get_peft_model_original(*args, **kwargs)
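
Note (not part of the patch): a minimal usage sketch of what this change enables. With the hasattr() fallback above, a stock peft.LoraConfig (which has no qa_lora field) can be passed straight to BigDL-LLM's get_peft_model, and qa_lora simply defaults to False. The model name and LoRA hyperparameters below are illustrative placeholders modeled on the qlora_finetuning.py example touched by this patch; adjust them for your own model and device.

    import torch
    from peft import LoraConfig
    from bigdl.llm.transformers import AutoModelForCausalLM
    from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training

    # Load the base model in 4-bit (NF4), as the QLoRA example does; these
    # from_pretrained arguments mirror qlora_finetuning.py.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                                 load_in_low_bit="nf4",
                                                 optimize_model=False,
                                                 torch_dtype=torch.float16,
                                                 modules_to_not_convert=["lm_head"])
    model = model.to("xpu")
    model = prepare_model_for_kbit_training(model)

    # A plain peft LoraConfig has no qa_lora attribute; the patched
    # _create_new_module now tolerates that and falls back to qa_lora=False.
    config = LoraConfig(r=8, lora_alpha=32,
                        target_modules=["q_proj", "k_proj", "v_proj"],
                        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
    model = get_peft_model(model, config)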