Support peft LoraConfig (#9636)
* support peft loraconfig * use testcase to test * fix style * meet comments
This commit is contained in:
		
							parent
							
								
									0b6f29a7fc
								
							
						
					
					
						commit
						70f5e7bf0d
					
				
					 2 changed files with 4 additions and 3 deletions
				
			
		| 
						 | 
					@ -20,8 +20,8 @@ import os
 | 
				
			||||||
import transformers
 | 
					import transformers
 | 
				
			||||||
from transformers import LlamaTokenizer
 | 
					from transformers import LlamaTokenizer
 | 
				
			||||||
import intel_extension_for_pytorch as ipex
 | 
					import intel_extension_for_pytorch as ipex
 | 
				
			||||||
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
 | 
					from peft import LoraConfig
 | 
				
			||||||
    LoraConfig
 | 
					from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 | 
				
			||||||
from bigdl.llm.transformers import AutoModelForCausalLM
 | 
					from bigdl.llm.transformers import AutoModelForCausalLM
 | 
				
			||||||
from datasets import load_dataset
 | 
					from datasets import load_dataset
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -136,7 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
 | 
				
			||||||
        low_bit_kwargs.update(
 | 
					        low_bit_kwargs.update(
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "qtype": target.qtype,
 | 
					                "qtype": target.qtype,
 | 
				
			||||||
                "qa_lora": lora_config.qa_lora,
 | 
					                "qa_lora": lora_config.qa_lora if hasattr(lora_config, "qa_lora") else False,
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        new_module = LoraLowBitLinear(adapter_name,
 | 
					        new_module = LoraLowBitLinear(adapter_name,
 | 
				
			||||||
| 
						 | 
					@ -165,6 +165,7 @@ def get_peft_model(*args, **kwargs):
 | 
				
			||||||
    old_create_new_module = LoraModel._create_new_module
 | 
					    old_create_new_module = LoraModel._create_new_module
 | 
				
			||||||
    LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
 | 
					    LoraModel._create_new_module = staticmethod(functools.partial(_create_new_module,
 | 
				
			||||||
                                                                  old_create_new_module))
 | 
					                                                                  old_create_new_module))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        from peft import get_peft_model as get_peft_model_original
 | 
					        from peft import get_peft_model as get_peft_model_original
 | 
				
			||||||
        model = get_peft_model_original(*args, **kwargs)
 | 
					        model = get_peft_model_original(*args, **kwargs)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue