fix non-string deepspeed config path bug (#11476)

* fix non-string deepspeed config path bug

* Update lora_finetune_chatglm.py
Heyang Sun 2024-07-01 15:53:50 +08:00 committed by GitHub
parent 48ad482d3d
commit 913e750b01

@@ -46,6 +46,7 @@ import typer
 from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
 from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
 from peft import (
+    get_peft_model,
     PeftConfig,
     PeftModelForCausalLM,
     get_peft_config
@@ -53,6 +54,7 @@ from peft import (
 from rouge_chinese import Rouge
 from torch import nn
 from transformers import (
+    AutoModelForCausalLM,
     AutoTokenizer,
     EvalPrediction,
     GenerationConfig,
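
A minimal sketch (illustrative only, not taken from this script) of how the two newly imported names are typically used together: AutoModelForCausalLM loads the base model and get_peft_model wraps it with a LoRA adapter. The model id, target module name, and LoRA hyperparameters below are assumptions; the script itself reads them from its config file.

# Hedged sketch: typical get_peft_model / AutoModelForCausalLM usage.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Assumed model id and LoRA hyperparameters, not taken from the repo.
base_model = AutoModelForCausalLM.from_pretrained(
    "THUDM/chatglm3-6b", trust_remote_code=True
)
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["query_key_value"],  # ChatGLM's fused attention projection
    r=8, lora_alpha=32, lora_dropout=0.1,
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # LoRA trains only a small fraction of weights
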
@@ -471,7 +473,10 @@ def main(
         ],
         config_file: Annotated[str, typer.Argument(help='')],
         # Add below L474, which is path of deepspeed config file to enable finetuning on 2 Intel Arc XPU cards
-        deepspeed_config_file: Annotated[str, typer.Argument(default='', help='if specified, will apply data parallel')]
+        deepspeed_config_file: str = typer.Argument(
+            default='',
+            help='if specified, will apply data parallel'
+        ),
         auto_resume_from_checkpoint: str = typer.Argument(
             default='',
             help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
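
The change above replaces the Annotated[...] form, which put the default inside typer.Argument, with the plain default-value form, so an omitted command-line argument arrives as an empty string rather than as the non-string value the old declaration produced (the bug named in the commit title). A minimal standalone sketch of the fixed declaration style (an assumed example, not the repo's code):

import typer

app = typer.Typer()

@app.command()
def main(
    deepspeed_config_file: str = typer.Argument(
        default='',
        help='if specified, will apply data parallel'
    ),
):
    # With this declaration the value is always a plain str ('' when the
    # argument is omitted), so a simple string comparison is safe.
    if deepspeed_config_file != '':
        typer.echo(f"Using DeepSpeed config: {deepspeed_config_file}")
    else:
        typer.echo("No DeepSpeed config given; running without data parallel.")

if __name__ == '__main__':
    app()
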
@@ -541,7 +546,7 @@ def main(
     use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
     # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
-    if deepspeed_config_file is not '':
+    if deepspeed_config_file != '':
         ft_config.training_args.ddp_backend = "ccl"
         ft_config.training_args.deepspeed = deepspeed_config_file
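
The two assignments above hand the parsed path to the Hugging Face training arguments: the deepspeed field of TrainingArguments accepts a path to a DeepSpeed JSON config, and ddp_backend = "ccl" selects the Intel oneCCL distributed backend used for the two Arc XPU cards. A hedged sketch of the same settings expressed directly (output directory and config file name are placeholders, not values from the repo):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",        # placeholder
    ddp_backend="ccl",            # oneCCL backend for Intel Arc XPUs
    deepspeed="ds_config.json",   # placeholder path to a DeepSpeed JSON config
)

The finetuning script would then be started through a multi-process launcher (for example mpirun or torchrun with two ranks) so that each XPU card gets its own process.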