Support peft LoraConfig (#9636)

* support peft loraconfig

* use testcase to test

* fix style

* meet comments
This commit is contained in:
Yina Chen 2023-12-08 16:13:03 +08:00 committed by GitHub
parent 0b6f29a7fc
commit 70f5e7bf0d
2 changed files with 4 additions and 3 deletions

View file

@ -20,8 +20,8 @@ import os
import transformers import transformers
from transformers import LlamaTokenizer from transformers import LlamaTokenizer
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\ from peft import LoraConfig
LoraConfig from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers import AutoModelForCausalLM from bigdl.llm.transformers import AutoModelForCausalLM
from datasets import load_dataset from datasets import load_dataset
import argparse import argparse

View file

@ -136,7 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
low_bit_kwargs.update( low_bit_kwargs.update(
{ {
"qtype": target.qtype, "qtype": target.qtype,
"qa_lora": lora_config.qa_lora, "qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False,
} }
) )
new_module = LoraLowBitLinear(adapter_name, new_module = LoraLowBitLinear(adapter_name,
@ -165,6 +165,7 @@ def get_peft_model(*args, **kwargs):
old_create_new_module = LoraModel._create_new_module old_create_new_module = LoraModel._create_new_module
LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module, LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
old_create_new_module)) old_create_new_module))
try: try:
from peft import get_peft_model as get_peft_model_original from peft import get_peft_model as get_peft_model_original
model = get_peft_model_original(*args, **kwargs) model = get_peft_model_original(*args, **kwargs)