From 96fd26759c55a7cd74ff41a922f34ec868b67827 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Mon, 20 Nov 2023 14:31:24 +0800
Subject: [PATCH] LLM: fix QLoRA finetuning example on CPU (#9489)

---
 .../example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py b/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py
index 2863251e..617ae25d 100644
--- a/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py
+++ b/python/llm/example/CPU/QLoRA-FineTuning/qlora_finetuning_cpu.py
@@ -25,6 +25,7 @@ from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
+from bigdl.llm.utils.isa_checker import ISAChecker
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
@@ -63,6 +64,11 @@ if __name__ == "__main__":
     model = get_peft_model(model, config)
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
+
+    # To avoid using only one core on client CPUs, enable bf16 only when AVX512 is supported
+    isa_checker = ISAChecker()
+    bf16_flag = isa_checker.check_avx512()
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data["train"],
@@ -73,7 +79,7 @@
         max_steps=200,
         learning_rate=2e-4,
         save_steps=100,
-        bf16=True,
+        bf16=bf16_flag,
         logging_steps=20,
         output_dir="outputs",
         optim="adamw_hf", # paged_adamw_8bit is not supported yet
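
Note: below is a minimal standalone sketch of the guard this patch introduces, for readers adapting the same pattern in their own training scripts. It assumes the `ISAChecker` class and its `check_avx512()` method behave as used in the diff above (both names are taken from the patch, not verified independently); `transformers.TrainingArguments` is the standard Hugging Face API.

    import transformers
    from bigdl.llm.utils.isa_checker import ISAChecker

    # Probe the host CPU once: per the patch comment, bf16=True ends up
    # using only one core on client CPUs without AVX512, so fall back
    # to fp32 training when AVX512 is missing.
    bf16_flag = ISAChecker().check_avx512()

    training_args = transformers.TrainingArguments(
        output_dir="outputs",
        bf16=bf16_flag,  # True only on AVX512-capable CPUs
    )

Computing the flag once up front and passing it through keeps the precision decision in a single place, rather than scattering ISA checks across the script.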