Support peft LoraConfig (#9636)
* support peft loraconfig * use testcase to test * fix style * meet comments
This commit is contained in:
parent
0b6f29a7fc
commit
70f5e7bf0d
2 changed files with 4 additions and 3 deletions
|
|
@ -20,8 +20,8 @@ import os
|
|||
import transformers
|
||||
from transformers import LlamaTokenizer
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
|
||||
LoraConfig
|
||||
from peft import LoraConfig
|
||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from datasets import load_dataset
|
||||
import argparse
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
|
|||
low_bit_kwargs.update(
|
||||
{
|
||||
"qtype": target.qtype,
|
||||
"qa_lora": lora_config.qa_lora,
|
||||
"qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False,
|
||||
}
|
||||
)
|
||||
new_module = LoraLowBitLinear(adapter_name,
|
||||
|
|
@ -165,6 +165,7 @@ def get_peft_model(*args, **kwargs):
|
|||
old_create_new_module = LoraModel._create_new_module
|
||||
LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
|
||||
old_create_new_module))
|
||||
|
||||
try:
|
||||
from peft import get_peft_model as get_peft_model_original
|
||||
model = get_peft_model_original(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue