LLM: fix QLoRA finetuning example on CPU (#9489)
This commit is contained in:
parent
0f9a440b06
commit
96fd26759c
1 changed files with 7 additions and 1 deletions
|
|
@ -25,6 +25,7 @@ from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_
|
|||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from datasets import load_dataset
|
||||
import argparse
|
||||
from bigdl.llm.utils.isa_checker import ISAChecker
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
|
||||
|
|
@ -63,6 +64,11 @@ if __name__ == "__main__":
|
|||
model = get_peft_model(model, config)
|
||||
tokenizer.pad_token_id = 0
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# To avoid only one core is used on client CPU
|
||||
isa_checker = ISAChecker()
|
||||
bf16_flag = isa_checker.check_avx512()
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=data["train"],
|
||||
|
|
@ -73,7 +79,7 @@ if __name__ == "__main__":
|
|||
max_steps=200,
|
||||
learning_rate=2e-4,
|
||||
save_steps=100,
|
||||
bf16=True,
|
||||
bf16=bf16_flag,
|
||||
logging_steps=20,
|
||||
output_dir="outputs",
|
||||
optim="adamw_hf", # paged_adamw_8bit is not supported yet
|
||||
|
|
|
|||
Loading…
Reference in a new issue