LLM: fix QLoRA finetuning example on CPU (#9489)

This commit is contained in:
binbin Deng 2023-11-20 14:31:24 +08:00 committed by GitHub
parent 0f9a440b06
commit 96fd26759c

View file

@@ -25,6 +25,7 @@ from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
+from bigdl.llm.utils.isa_checker import ISAChecker

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
@@ -63,6 +64,11 @@ if __name__ == "__main__":
     model = get_peft_model(model, config)
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
+    # To avoid only one core is used on client CPU
+    isa_checker = ISAChecker()
+    bf16_flag = isa_checker.check_avx512()
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data["train"],
@@ -73,7 +79,7 @@ if __name__ == "__main__":
         max_steps=200,
         learning_rate=2e-4,
         save_steps=100,
-        bf16=True,
+        bf16=bf16_flag,
         logging_steps=20,
         output_dir="outputs",
         optim="adamw_hf", # paged_adamw_8bit is not supported yet