LLM: fix loss error on Arc (#9550)
parent 65121c7997
commit 4ff2ca9d0d
3 changed files with 6 additions and 8 deletions
@@ -51,8 +51,7 @@ import intel_extension_for_pytorch as ipex
 from bigdl.llm.transformers import AutoModelForCausalLM
 
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
-    cast_lora_weight
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
@@ -283,9 +282,6 @@ def train(
     # model.is_parallelizable = True
     # model.model_parallel = True
 
-    if bf16:
-        cast_lora_weight(model, torch.bfloat16)
-
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
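The two hunks above remove the manual bfloat16 cast from this fine-tuning example: both the `cast_lora_weight` import and the `if bf16:` call site go away, because the cast now happens inside the library (see the final hunk, in bigdl.llm.transformers.qlora). A minimal sketch of the resulting example-side flow; the model id and LoRA hyperparameters below are illustrative placeholders, not values from this diff:

import torch
import intel_extension_for_pytorch as ipex  # registers the "xpu" device
from peft import LoraConfig
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training

# Load a low-bit base model and move it to the Arc GPU ("xpu").
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # placeholder model id
    load_in_low_bit="nf4",
    optimize_model=False,
)
model = model.to("xpu")
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=8, lora_alpha=32, lora_dropout=0.05,   # illustrative hyperparameters
    target_modules=["q_proj", "k_proj", "v_proj"],
    bias="none", task_type="CAUSAL_LM",
)

# After this commit, get_peft_model casts the LoRA weights to bfloat16
# itself when the model is on XPU, so no explicit cast is needed here.
model = get_peft_model(model, config)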
@@ -21,8 +21,7 @@ import transformers
 from transformers import LlamaTokenizer
 from peft import LoraConfig
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training, \
-    cast_lora_weight
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
@@ -61,7 +60,6 @@ if __name__ == "__main__":
         task_type="CAUSAL_LM"
     )
     model = get_peft_model(model, config)
-    cast_lora_weight(model, torch.bfloat16)
 
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
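Same simplification in the second example: `get_peft_model(model, config)` is now enough on its own. Note that `cast_lora_weight` itself is not removed from the library (the hunk below still calls it), so a caller who wants an explicit cast can still do what this example used to do; a one-line sketch, where `model` is the Peft-wrapped model returned by `get_peft_model`:

import torch
from bigdl.llm.transformers.qlora import cast_lora_weight

cast_lora_weight(model, torch.bfloat16)  # explicit cast, as the example formerly did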
@@ -152,6 +152,10 @@ def get_peft_model(*args, **kwargs):
     finally:
         LoraModel._create_new_module = old_create_new_module
 
+    if model.device.type == "xpu":
+        cast_lora_weight(model, torch.bfloat16)
+        torch.xpu.synchronize()
+
     return model
 
 
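This last hunk is the actual fix: `get_peft_model` now performs the bfloat16 cast itself whenever the model sits on an XPU (Arc) device, and the explicit `torch.xpu.synchronize()` presumably ensures the conversion has completed on the device before training continues. The body of `cast_lora_weight` is not shown in this diff; as a rough mental model only (an assumption, not BigDL-LLM's actual implementation), a LoRA weight cast typically walks the module tree and converts just the adapter matrices, leaving the quantized base weights untouched:

import torch

def cast_lora_weight_sketch(model, dtype=torch.bfloat16):
    # Hypothetical stand-in for bigdl.llm.transformers.qlora.cast_lora_weight:
    # cast only the LoRA adapter modules (lora_A / lora_B) to `dtype`;
    # the low-bit base weights are left as they are.
    for name, module in model.named_modules():
        if "lora_A" in name or "lora_B" in name:
            module.to(dtype)
    return model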