Support peft LoraConfig (#9636)

* Support peft LoraConfig

* Use a test case for verification

* Fix style

* Address review comments
Yina Chen, 2023-12-08 16:13:03 +08:00, committed by GitHub
parent 0b6f29a7fc
commit 70f5e7bf0d
2 changed files with 4 additions and 3 deletions

@@ -20,8 +20,8 @@ import os
 import transformers
 from transformers import LlamaTokenizer
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
-    LoraConfig
+from peft import LoraConfig
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
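
With this change the example takes LoraConfig from peft itself rather than from bigdl's qlora module. Below is a minimal sketch of how the pieces now fit together; the model name and LoRA hyperparameters are illustrative assumptions, not values from this commit:

from peft import LoraConfig
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers import AutoModelForCausalLM

# Assumed base model and low-bit setting, for illustration only.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_low_bit="nf4")
model = prepare_model_for_kbit_training(model)

# A vanilla peft LoraConfig, with no bigdl-specific fields such as qa_lora.
config = LoraConfig(r=8, lora_alpha=32,
                    target_modules=["q_proj", "k_proj", "v_proj"],
                    lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(model, config)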

@@ -136,7 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
         low_bit_kwargs.update(
             {
                 "qtype": target.qtype,
-                "qa_lora": lora_config.qa_lora,
+                "qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False,
             }
         )
         new_module = LoraLowBitLinear(adapter_name,
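
The hasattr guard is what lets a plain peft LoraConfig flow through this code path: bigdl's own LoraConfig subclass carries a qa_lora field, while peft's does not, so an unconditional attribute read would raise AttributeError. A small self-contained sketch of the pattern; the stand-in classes are assumptions for demonstration, not the real config types:

from dataclasses import dataclass

@dataclass
class PeftLikeConfig:       # stands in for peft.LoraConfig: no qa_lora field
    r: int = 8

@dataclass
class BigdlLikeConfig:      # stands in for bigdl's extended config
    r: int = 8
    qa_lora: bool = True

for cfg in (PeftLikeConfig(), BigdlLikeConfig()):
    qa_lora = cfg.qa_lora if hasattr(cfg, "qa_lora") else False
    print(type(cfg).__name__, "->", qa_lora)   # False for the peft-like config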
@@ -165,6 +165,7 @@ def get_peft_model(*args, **kwargs):
     old_create_new_module = LoraModel._create_new_module
     LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
                                                                   old_create_new_module))
-    from peft import get_peft_model as get_peft_model_original
-    model = get_peft_model_original(*args, **kwargs)
+    try:
+        from peft import get_peft_model as get_peft_model_original
+        model = get_peft_model_original(*args, **kwargs)
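
The surrounding function swaps LoraModel._create_new_module for a patched version before delegating to peft's get_peft_model; wrapping the delegation in try: suggests the original method is restored afterwards (the excerpt cuts off before any finally block). A minimal sketch of this patch-then-restore pattern, using hypothetical stand-in names:

import functools

class SomeModel:                       # hypothetical stand-in for LoraModel
    @staticmethod
    def _create_new_module(name):
        return f"original:{name}"

def _patched(original_func, name):
    # Wrap the saved original instead of replacing it outright.
    return f"patched({original_func(name)})"

old = SomeModel._create_new_module
SomeModel._create_new_module = staticmethod(functools.partial(_patched, old))
try:
    print(SomeModel._create_new_module("linear"))    # patched(original:linear)
finally:
    SomeModel._create_new_module = staticmethod(old)  # always undo the patch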