parent dbbdb53a18
commit 0f82b8c3a0
2 changed files with 7 additions and 4 deletions
@@ -17,7 +17,7 @@ conda activate llm
 # below command will install intel_extension_for_pytorch==2.0.110+xpu as default
 # you can install specific ipex/torch version for your need
 pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
-pip install transformers==4.34.0
+pip install datasets transformers==4.34.0
 pip install peft==0.5.0
 pip install accelerate==0.23.0
 ```
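The only change in this file is adding the `datasets` package to the install step. For context, a minimal sketch of what that package is typically used for when finetuning; the model id, dataset name, and text field below are placeholders, not taken from this commit:

```python
# Illustrative sketch only: shows what the newly installed `datasets` package
# is for. The model id, dataset name, and text field are placeholders and are
# not taken from this diff.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id
data = load_dataset("Abirate/english_quotes")                          # placeholder dataset
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
```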
@@ -48,7 +48,9 @@ if __name__ == "__main__":
         torch_dtype=torch.float16,
         modules_to_not_convert=["lm_head"],)
     model = model.to('xpu')
-    model.gradient_checkpointing_enable()
+    # Enable gradient_checkpointing if your memory is not enough,
+    # it will slow down the training speed
+    # model.gradient_checkpointing_enable()
     model = prepare_model_for_kbit_training(model)
     config = LoraConfig(
         r=8,
@@ -69,9 +71,10 @@ if __name__ == "__main__":
         gradient_accumulation_steps= 1,
         warmup_steps=20,
         max_steps=200,
-        learning_rate=2e-4,
+        learning_rate=2e-5,
         save_steps=100,
-        fp16=True,
+        # fp16=True,
+        bf16=True, # bf16 is more stable in training
         logging_steps=20,
         output_dir="outputs",
         optim="adamw_hf", # paged_adamw_8bit is not supported yet
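Taken together, this hunk lowers the learning rate to 2e-5 and switches mixed precision from fp16 to bf16. For orientation, a minimal sketch of how these values sit inside a typical `transformers.TrainingArguments` call; the batch size and any surrounding Trainer wiring are assumptions, not part of this diff:

```python
# Illustrative sketch only: places the values changed by this commit
# (learning_rate=2e-5, bf16 instead of fp16) inside a standard
# transformers.TrainingArguments. per_device_train_batch_size is an assumed
# value; the rest mirrors the context lines of the hunk above.
import transformers

training_args = transformers.TrainingArguments(
    per_device_train_batch_size=4,  # assumed, not shown in this diff
    gradient_accumulation_steps=1,
    warmup_steps=20,
    max_steps=200,
    learning_rate=2e-5,             # lowered from 2e-4 by this commit
    save_steps=100,
    bf16=True,                      # bf16 is more stable in training than fp16
    logging_steps=20,
    output_dir="outputs",
    optim="adamw_hf",               # paged_adamw_8bit is not supported yet
)
```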