Support LoRA ChatGLM with Alpaca Dataset (#11580)
* Support LoRA ChatGLM with Alpaca Dataset * refine * fix * add 2-card alpaca
This commit is contained in:
parent
99c22745b2
commit
365adad59f
6 changed files with 162 additions and 20 deletions
|
|
@ -31,7 +31,9 @@ source /opt/intel/oneapi/setvars.sh
|
|||
|
||||
### 3. LoRA Fine-Tune on ChatGLM3-6B
|
||||
|
||||
First, download the dataset: we use `AdvertiseGen` to finetune ChatGLM3-6B in the following, and please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script:
|
||||
First, as for the dataset, you have two options:
|
||||
|
||||
1. `AdvertiseGen`: please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script:
|
||||
|
||||
```bash
|
||||
python process_advertise_gen_dataset.py
|
||||
|
|
@ -39,12 +41,20 @@ python process_advertise_gen_dataset.py
|
|||
|
||||
Then, './AdvertiseGen' will be converted to './AdvertiseGen_fix'. Now, we have prepared the dataset, and are going to start LoRA fine-tuning on ChatGLM3-6B.
|
||||
|
||||
2. `Alapca`: We also support [yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned) that contains generated instructions and demonstrations. It does not require preprocessing, and please directy run the following script.
|
||||
|
||||
#### 3.1. Fine-Tune with a Single Arc Card
|
||||
|
||||
Start the fine-tuning by:
|
||||
1. For `AdvertiseGen`, start the fine-tuning by:
|
||||
|
||||
```bash
|
||||
bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
|
||||
bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_1_arc_card.sh
|
||||
```
|
||||
|
||||
2. For `Alpaca`, start the fine-tuning by:
|
||||
|
||||
```bash
|
||||
bash lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh
|
||||
```
|
||||
|
||||
Then, you will get output are as below:
|
||||
|
|
@ -145,6 +155,14 @@ Training completed. Do not forget to share your model on huggingface.co/models =
|
|||
|
||||
Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:
|
||||
|
||||
1. `AdvertiseGen` dataset:
|
||||
|
||||
```bash
|
||||
bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
|
||||
bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_2_arc_cards.sh
|
||||
```
|
||||
|
||||
2. `Alpaca` dataset:
|
||||
|
||||
```bash
|
||||
bash lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh
|
||||
```
|
||||
|
|
|
|||
|
|
@ -65,8 +65,15 @@ from transformers import (
|
|||
)
|
||||
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
common_util_path = os.path.join(current_dir, '..', '..')
|
||||
import sys
|
||||
sys.path.append(common_util_path)
|
||||
from common.utils import get_train_val_data, Prompter
|
||||
|
||||
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
|
||||
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
app = typer.Typer(pretty_exceptions_show_locals=False)
|
||||
|
|
@ -247,7 +254,7 @@ def _load_datasets(
|
|||
return dataset_dct
|
||||
|
||||
|
||||
class DataManager(object):
|
||||
class AdvertiseGenDataManager(object):
|
||||
def __init__(self, data_dir: str, data_config: DataConfig):
|
||||
self._num_proc = data_config.num_proc
|
||||
|
||||
|
|
@ -283,6 +290,52 @@ class DataManager(object):
|
|||
num_proc=self._num_proc,
|
||||
)
|
||||
|
||||
class AlpacaDataConfig(object):
|
||||
def __init__(self, tokenizer, prompter, train_on_inputs,
|
||||
add_eos_token, cutoff_len, val_set_size, seed):
|
||||
self.tokenizer = tokenizer
|
||||
self.prompter = prompter
|
||||
self.train_on_inputs = train_on_inputs
|
||||
self.add_eos_token = add_eos_token
|
||||
self.cutoff_len = cutoff_len
|
||||
self.val_set_size = val_set_size
|
||||
self.seed = seed
|
||||
|
||||
|
||||
class AlpacaDataManager(object):
|
||||
def __init__(self, data_dir: str, data_config: AlpacaDataConfig):
|
||||
if data_dir.endswith(".json") or data_dir.endswith(".jsonl"):
|
||||
data = load_dataset("json", data_files=data_dir)
|
||||
else:
|
||||
data = load_dataset(data_dir)
|
||||
self.train_data, self.val_data = get_train_val_data(
|
||||
data,
|
||||
data_config.tokenizer,
|
||||
data_config.prompter,
|
||||
data_config.train_on_inputs,
|
||||
data_config.add_eos_token,
|
||||
data_config.cutoff_len,
|
||||
data_config.val_set_size,
|
||||
seed=data_config.seed)
|
||||
self.train_data = self.train_data.remove_columns(
|
||||
['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
|
||||
self.val_data = self.val_data.remove_columns(
|
||||
['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
|
||||
|
||||
def get_dataset(
|
||||
self,
|
||||
split: NamedSplit,
|
||||
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
|
||||
batched: bool = True,
|
||||
remove_orig_columns: bool = True,
|
||||
) -> Optional[Dataset]:
|
||||
if split == Split.TRAIN:
|
||||
return self.train_data
|
||||
elif split == Split.VALIDATION:
|
||||
return self.val_data
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def print_model_size(model: PreTrainedModel):
|
||||
print("--> Model")
|
||||
|
|
@ -484,7 +537,17 @@ def main(
|
|||
):
|
||||
ft_config = FinetuningConfig.from_file(config_file)
|
||||
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
|
||||
data_manager = DataManager(data_dir, ft_config.data_config)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if 'AdvertiseGen' in data_dir:
|
||||
data_manager = AdvertiseGenDataManager(data_dir, ft_config.data_config)
|
||||
elif 'alpaca' in data_dir:
|
||||
data_config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
|
||||
train_on_inputs=True, add_eos_token=False,
|
||||
cutoff_len=256, val_set_size=2000, seed=42)
|
||||
data_manager = AlpacaDataManager(data_dir, data_config)
|
||||
else:
|
||||
raise NotImplementedError("Wrong dataset, currently only support AdvertiseGen and Alpaca")
|
||||
|
||||
train_dataset = data_manager.get_dataset(
|
||||
Split.TRAIN,
|
||||
|
|
@ -530,38 +593,47 @@ def main(
|
|||
# turn model to fp32
|
||||
_prepare_model_for_training(model, ft_config.training_args.use_cpu)
|
||||
|
||||
ft_config.training_args.generation_config.pad_token_id = (
|
||||
tokenizer.pad_token_id
|
||||
)
|
||||
ft_config.training_args.generation_config.eos_token_id = [
|
||||
tokenizer.eos_token_id,
|
||||
tokenizer.get_command('<|user|>'),
|
||||
tokenizer.get_command('<|observation|>'),
|
||||
]
|
||||
if 'AdvertiseGen' in data_dir:
|
||||
ft_config.training_args.generation_config.pad_token_id = (
|
||||
tokenizer.pad_token_id
|
||||
)
|
||||
ft_config.training_args.generation_config.eos_token_id = [
|
||||
tokenizer.eos_token_id,
|
||||
tokenizer.get_command('<|user|>'),
|
||||
tokenizer.get_command('<|observation|>'),
|
||||
]
|
||||
model.gradient_checkpointing_enable()
|
||||
model.enable_input_require_grads()
|
||||
|
||||
use_tokenizer = True
|
||||
if ft_config.peft_config is not None:
|
||||
use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
|
||||
if 'AdvertiseGen' in data_dir:
|
||||
use_tokenizer = True
|
||||
if ft_config.peft_config is not None:
|
||||
use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
|
||||
else:
|
||||
use_tokenizer = False
|
||||
|
||||
# Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
|
||||
if deepspeed_config_file != '':
|
||||
ft_config.training_args.ddp_backend = "ccl"
|
||||
ft_config.training_args.deepspeed = deepspeed_config_file
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
|
||||
|
||||
BASE_TRAINER = Trainer if 'alpaca' in data_dir else Seq2SeqTrainer
|
||||
|
||||
trainer = BASE_TRAINER(
|
||||
model=model,
|
||||
args=ft_config.training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
padding='longest',
|
||||
return_tensors='pt',
|
||||
padding=True if 'alpaca' in data_dir else 'longest',
|
||||
pad_to_multiple_of=8 if 'alpaca' in data_dir else None,
|
||||
),
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset.select(list(range(50))),
|
||||
tokenizer=tokenizer if use_tokenizer else None, # LORA does not need tokenizer
|
||||
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
|
||||
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer) if 'AdvertiseGen' in data_dir else None,
|
||||
)
|
||||
|
||||
if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
export BIGDL_CHECK_DUPLICATE_IMPORT=0
|
||||
|
||||
# You can also set the remote model repository to a local model path
|
||||
python lora_finetune_chatglm.py \
|
||||
yahma/alpaca-cleaned \
|
||||
THUDM/chatglm3-6b \
|
||||
./lora_config.yaml
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
export MASTER_ADDR=127.0.0.1
|
||||
export OMP_NUM_THREADS=6
|
||||
export FI_PROVIDER=tcp
|
||||
export CCL_ATL_TRANSPORT=ofi
|
||||
export BIGDL_CHECK_DUPLICATE_IMPORT=0
|
||||
|
||||
# You can also set the remote model repository to a local model path
|
||||
mpirun -n 2 \
|
||||
python lora_finetune_chatglm.py \
|
||||
yahma/alpaca-cleaned \
|
||||
THUDM/chatglm3-6b \
|
||||
./lora_config.yaml \
|
||||
./deepspeed_config.json
|
||||
Loading…
Reference in a new issue