LLM: fix QLoRA finetuning example on CPU (#9489)
parent 0f9a440b06
commit 96fd26759c
1 changed file with 7 additions and 1 deletion
@@ -25,6 +25,7 @@ from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
+from bigdl.llm.utils.isa_checker import ISAChecker
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
@@ -63,6 +64,11 @@ if __name__ == "__main__":
     model = get_peft_model(model, config)
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
+
+    # To avoid only one core is used on client CPU
+    isa_checker = ISAChecker()
+    bf16_flag = isa_checker.check_avx512()
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=data["train"],
@@ -73,7 +79,7 @@ if __name__ == "__main__":
             max_steps=200,
             learning_rate=2e-4,
             save_steps=100,
-            bf16=True,
+            bf16=bf16_flag,
             logging_steps=20,
             output_dir="outputs",
             optim="adamw_hf", # paged_adamw_8bit is not supported yet
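In substance, the fix replaces the hard-coded bf16=True in the training arguments with a flag derived from a runtime instruction-set check, so the example no longer assumes AVX512/bf16-capable hardware. The sketch below isolates that pattern; it is not the full example. ISAChecker and check_avx512() are exactly the names imported in the diff, the TrainingArguments values mirror the ones visible above, and wiring them into transformers.Trainer is only summarized in the trailing comment rather than reproduced in full.

import transformers
from bigdl.llm.utils.isa_checker import ISAChecker

# Decide at runtime whether bf16 mixed precision can be used on this CPU,
# instead of assuming it unconditionally.
isa_checker = ISAChecker()
bf16_flag = isa_checker.check_avx512()  # True when the CPU reports AVX512 support

# Same training arguments as in the diff above; only bf16 is now conditional.
training_args = transformers.TrainingArguments(
    max_steps=200,
    learning_rate=2e-4,
    save_steps=100,
    bf16=bf16_flag,    # previously hard-coded to True
    logging_steps=20,
    output_dir="outputs",
    optim="adamw_hf",  # paged_adamw_8bit is not supported yet
)

# training_args would then be passed to transformers.Trainer(model=..., args=training_args, ...)
# together with the model, tokenizer and dataset prepared earlier in the example.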