Support peft LoraConfig (#9636)
* support peft loraconfig * use testcase to test * fix style * meet comments
This commit is contained in:
		
							parent
							
								
									0b6f29a7fc
								
							
						
					
					
						commit
						70f5e7bf0d
					
				
					 2 changed files with 4 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -20,8 +20,8 @@ import os
 | 
			
		|||
import transformers
 | 
			
		||||
from transformers import LlamaTokenizer
 | 
			
		||||
import intel_extension_for_pytorch as ipex
 | 
			
		||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
 | 
			
		||||
    LoraConfig
 | 
			
		||||
from peft import LoraConfig
 | 
			
		||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 | 
			
		||||
from bigdl.llm.transformers import AutoModelForCausalLM
 | 
			
		||||
from datasets import load_dataset
 | 
			
		||||
import argparse
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -136,7 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
 | 
			
		|||
        low_bit_kwargs.update(
 | 
			
		||||
            {
 | 
			
		||||
                "qtype": target.qtype,
 | 
			
		||||
                "qa_lora": lora_config.qa_lora,
 | 
			
		||||
                "qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False,
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        new_module = LoraLowBitLinear(adapter_name,
 | 
			
		||||
| 
						 | 
				
			
			@ -165,6 +165,7 @@ def get_peft_model(*args, **kwargs):
 | 
			
		|||
    old_create_new_module = LoraModel._create_new_module
 | 
			
		||||
    LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
 | 
			
		||||
                                                                  old_create_new_module))
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        from peft import get_peft_model as get_peft_model_original
 | 
			
		||||
        model = get_peft_model_original(*args, **kwargs)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue