LLM: fix QLoRA finetuning example on CPU (#9489)
This commit is contained in:
		
							parent
							
								
									0f9a440b06
								
							
						
					
					
						commit
						96fd26759c
					
				
					 1 changed files with 7 additions and 1 deletions
				
			
		| 
						 | 
					@ -25,6 +25,7 @@ from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_
 | 
				
			||||||
from bigdl.llm.transformers import AutoModelForCausalLM
 | 
					from bigdl.llm.transformers import AutoModelForCausalLM
 | 
				
			||||||
from datasets import load_dataset
 | 
					from datasets import load_dataset
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					from bigdl.llm.utils.isa_checker import ISAChecker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
					    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
				
			||||||
| 
						 | 
					@ -63,6 +64,11 @@ if __name__ == "__main__":
 | 
				
			||||||
    model = get_peft_model(model, config)
 | 
					    model = get_peft_model(model, config)
 | 
				
			||||||
    tokenizer.pad_token_id = 0
 | 
					    tokenizer.pad_token_id = 0
 | 
				
			||||||
    tokenizer.padding_side = "left"
 | 
					    tokenizer.padding_side = "left"
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # To avoid only one core is used on client CPU
 | 
				
			||||||
 | 
					    isa_checker = ISAChecker()
 | 
				
			||||||
 | 
					    bf16_flag = isa_checker.check_avx512()
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
    trainer = transformers.Trainer(
 | 
					    trainer = transformers.Trainer(
 | 
				
			||||||
        model=model,
 | 
					        model=model,
 | 
				
			||||||
        train_dataset=data["train"],
 | 
					        train_dataset=data["train"],
 | 
				
			||||||
| 
						 | 
					@ -73,7 +79,7 @@ if __name__ == "__main__":
 | 
				
			||||||
            max_steps=200,
 | 
					            max_steps=200,
 | 
				
			||||||
            learning_rate=2e-4,
 | 
					            learning_rate=2e-4,
 | 
				
			||||||
            save_steps=100,
 | 
					            save_steps=100,
 | 
				
			||||||
            bf16=True,
 | 
					            bf16=bf16_flag,
 | 
				
			||||||
            logging_steps=20,
 | 
					            logging_steps=20,
 | 
				
			||||||
            output_dir="outputs",
 | 
					            output_dir="outputs",
 | 
				
			||||||
            optim="adamw_hf",  # paged_adamw_8bit is not supported yet
 | 
					            optim="adamw_hf",  # paged_adamw_8bit is not supported yet
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue