Support peft LoraConfig (#9636)
* support peft loraconfig * use testcase to test * fix style * meet comments
This commit is contained in:
parent
0b6f29a7fc
commit
70f5e7bf0d
2 changed files with 4 additions and 3 deletions
|
|
@ -20,8 +20,8 @@ import os
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import LlamaTokenizer
|
from transformers import LlamaTokenizer
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
|
from peft import LoraConfig
|
||||||
LoraConfig
|
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
|
||||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
import argparse
|
import argparse
|
||||||
|
|
|
||||||
|
|
@ -136,7 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
|
||||||
low_bit_kwargs.update(
|
low_bit_kwargs.update(
|
||||||
{
|
{
|
||||||
"qtype": target.qtype,
|
"qtype": target.qtype,
|
||||||
"qa_lora": lora_config.qa_lora,
|
"qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
new_module = LoraLowBitLinear(adapter_name,
|
new_module = LoraLowBitLinear(adapter_name,
|
||||||
|
|
@ -165,6 +165,7 @@ def get_peft_model(*args, **kwargs):
|
||||||
old_create_new_module = LoraModel._create_new_module
|
old_create_new_module = LoraModel._create_new_module
|
||||||
LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
|
LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
|
||||||
old_create_new_module))
|
old_create_new_module))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from peft import get_peft_model as get_peft_model_original
|
from peft import get_peft_model as get_peft_model_original
|
||||||
model = get_peft_model_original(*args, **kwargs)
|
model = get_peft_model_original(*args, **kwargs)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue