Support LoRA ChatGLM with Alpaca Dataset (#11580)
* Support LoRA ChatGLM with Alpaca Dataset
* refine
* fix
* add 2-card alpaca

parent 99c22745b2
commit 365adad59f

6 changed files with 162 additions and 20 deletions
@@ -31,7 +31,9 @@ source /opt/intel/oneapi/setvars.sh

### 3. LoRA Fine-Tune on ChatGLM3-6B

-First, download the dataset: we use `AdvertiseGen` to finetune ChatGLM3-6B in the following, and please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script:
+First, prepare the dataset; you have two options:
+
+1. `AdvertiseGen`: download it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), unzip it in the current directory, and then process it with the script below:

```bash
python process_advertise_gen_dataset.py
@@ -39,12 +41,20 @@ python process_advertise_gen_dataset.py

Then, './AdvertiseGen' will be converted to './AdvertiseGen_fix'. Now, we have prepared the dataset, and are going to start LoRA fine-tuning on ChatGLM3-6B.

+2. `Alpaca`: we also support [yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned), which contains generated instructions and demonstrations. It requires no preprocessing; you can directly run the fine-tuning scripts below.

#### 3.1. Fine-Tune with a Single Arc Card

-Start the fine-tuning by:
+1. For `AdvertiseGen`, start the fine-tuning by:

```bash
-bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
+bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_1_arc_card.sh
```

+2. For `Alpaca`, start the fine-tuning by:
+
+```bash
+bash lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh
+```

Then, you will get output as below:
@@ -145,6 +155,14 @@ Training completed. Do not forget to share your model on huggingface.co/models =

Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:

+1. `AdvertiseGen` dataset:
+
```bash
-bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
+bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_2_arc_cards.sh
```

+2. `Alpaca` dataset:
+
+```bash
+bash lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh
+```
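
As a quick look at the `Alpaca` option added above, the snippet below is a small editorial sketch (not part of this commit) that loads `yahma/alpaca-cleaned` with the `datasets` library and prints one record. Each record carries `instruction`, `input`, and `output` fields, which are exactly the raw columns the new `AlpacaDataManager` drops again after tokenization.

```python
# Editorial sketch, not from this commit: preview the dataset used by the
# lora_finetuning_chatglm3_6b_on_alpaca_*.sh scripts.
from datasets import load_dataset

data = load_dataset("yahma/alpaca-cleaned")  # downloads from the Hugging Face Hub
print(data)                                  # DatasetDict with a single 'train' split
example = data["train"][0]
print(example["instruction"])                # the task description
print(example["input"])                      # optional extra context (may be empty)
print(example["output"])                     # the reference response
```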
@@ -65,8 +65,15 @@ from transformers import (
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq

+from transformers import Trainer
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

+current_dir = os.path.dirname(os.path.realpath(__file__))
+common_util_path = os.path.join(current_dir, '..', '..')
+import sys
+sys.path.append(common_util_path)
+from common.utils import get_train_val_data, Prompter

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)
@@ -247,7 +254,7 @@ def _load_datasets(
    return dataset_dct


-class DataManager(object):
+class AdvertiseGenDataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc
@@ -283,6 +290,52 @@ class DataManager(object):
            num_proc=self._num_proc,
        )

+
+class AlpacaDataConfig(object):
+    # Bundles the tokenizer, prompt template and preprocessing options
+    # needed to build the Alpaca train/validation splits.
+    def __init__(self, tokenizer, prompter, train_on_inputs,
+                 add_eos_token, cutoff_len, val_set_size, seed):
+        self.tokenizer = tokenizer
+        self.prompter = prompter
+        self.train_on_inputs = train_on_inputs
+        self.add_eos_token = add_eos_token
+        self.cutoff_len = cutoff_len
+        self.val_set_size = val_set_size
+        self.seed = seed
+
+
+class AlpacaDataManager(object):
+    def __init__(self, data_dir: str, data_config: AlpacaDataConfig):
+        # Accept either a local json/jsonl file or a Hugging Face dataset name.
+        if data_dir.endswith(".json") or data_dir.endswith(".jsonl"):
+            data = load_dataset("json", data_files=data_dir)
+        else:
+            data = load_dataset(data_dir)
+        # Tokenize and split into train/val, then drop the raw text columns.
+        self.train_data, self.val_data = get_train_val_data(
+            data,
+            data_config.tokenizer,
+            data_config.prompter,
+            data_config.train_on_inputs,
+            data_config.add_eos_token,
+            data_config.cutoff_len,
+            data_config.val_set_size,
+            seed=data_config.seed)
+        self.train_data = self.train_data.remove_columns(
+            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
+        self.val_data = self.val_data.remove_columns(
+            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
+
+    # Same interface as AdvertiseGenDataManager.get_dataset; the splits are
+    # already processed, so process_fn and the remaining arguments are ignored.
+    def get_dataset(
+            self,
+            split: NamedSplit,
+            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
+            batched: bool = True,
+            remove_orig_columns: bool = True,
+    ) -> Optional[Dataset]:
+        if split == Split.TRAIN:
+            return self.train_data
+        elif split == Split.VALIDATION:
+            return self.val_data
+        else:
+            return None


def print_model_size(model: PreTrainedModel):
    print("--> Model")
@@ -484,7 +537,17 @@ def main(
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
-    data_manager = DataManager(data_dir, ft_config.data_config)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    # Choose the data manager from the dataset path given on the command line.
+    if 'AdvertiseGen' in data_dir:
+        data_manager = AdvertiseGenDataManager(data_dir, ft_config.data_config)
+    elif 'alpaca' in data_dir:
+        data_config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
+                                       train_on_inputs=True, add_eos_token=False,
+                                       cutoff_len=256, val_set_size=2000, seed=42)
+        data_manager = AlpacaDataManager(data_dir, data_config)
+    else:
+        raise NotImplementedError("Unsupported dataset: only AdvertiseGen and Alpaca are currently supported")

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
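
`Prompter` and `get_train_val_data` live in the example's shared `common/utils.py`, which is outside this diff. Purely as a hedged illustration of what `Prompter("alpaca")` typically renders (this is the conventional alpaca-lora template, an assumption rather than this repo's actual implementation), each record is turned into a single training string roughly like this:

```python
# Assumption for illustration only: the conventional Alpaca prompt template,
# not necessarily the exact output of this repo's common.utils.Prompter.
def render_alpaca_prompt(instruction: str, inp: str = "", output: str = "") -> str:
    if inp:
        header = ("Below is an instruction that describes a task, paired with an input "
                  "that provides further context. "
                  "Write a response that appropriately completes the request.")
        body = f"\n\n### Instruction:\n{instruction}\n\n### Input:\n{inp}\n\n### Response:\n"
    else:
        header = ("Below is an instruction that describes a task. "
                  "Write a response that appropriately completes the request.")
        body = f"\n\n### Instruction:\n{instruction}\n\n### Response:\n"
    return header + body + output
```

The rendered string is then tokenized with `cutoff_len=256`, and `val_set_size=2000` examples are held out for validation, matching the `AlpacaDataConfig` arguments above.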
@@ -530,38 +593,47 @@ def main(
    # turn model to fp32
    _prepare_model_for_training(model, ft_config.training_args.use_cpu)

-    ft_config.training_args.generation_config.pad_token_id = (
-        tokenizer.pad_token_id
-    )
-    ft_config.training_args.generation_config.eos_token_id = [
-        tokenizer.eos_token_id,
-        tokenizer.get_command('<|user|>'),
-        tokenizer.get_command('<|observation|>'),
-    ]
+    if 'AdvertiseGen' in data_dir:
+        ft_config.training_args.generation_config.pad_token_id = (
+            tokenizer.pad_token_id
+        )
+        ft_config.training_args.generation_config.eos_token_id = [
+            tokenizer.eos_token_id,
+            tokenizer.get_command('<|user|>'),
+            tokenizer.get_command('<|observation|>'),
+        ]
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

-    use_tokenizer = True
-    if ft_config.peft_config is not None:
-        use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
+    if 'AdvertiseGen' in data_dir:
+        use_tokenizer = True
+        if ft_config.peft_config is not None:
+            use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
+    else:
+        use_tokenizer = False

    # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
    if deepspeed_config_file != '':
        ft_config.training_args.ddp_backend = "ccl"
        ft_config.training_args.deepspeed = deepspeed_config_file

-    trainer = Seq2SeqTrainer(
+    # Use the stock transformers Trainer and DataCollatorForSeq2Seq here; the
+    # local import shadows the module-level names within this function.
+    from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
+
+    BASE_TRAINER = Trainer if 'alpaca' in data_dir else Seq2SeqTrainer
+
+    trainer = BASE_TRAINER(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
-            padding='longest',
            return_tensors='pt',
+            padding=True if 'alpaca' in data_dir else 'longest',
+            pad_to_multiple_of=8 if 'alpaca' in data_dir else None,
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(50))),
        tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer
-        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
+        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer) if 'AdvertiseGen' in data_dir else None,
    )

    if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
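
To make the collator change concrete: on the Alpaca path the stock `DataCollatorForSeq2Seq` pads every batch to its longest sequence and then rounds that length up to a multiple of 8, while the AdvertiseGen path keeps the previous `padding='longest'` behaviour. The standalone sketch below (not part of the commit; the tokenizer choice is arbitrary) shows the effect:

```python
# Standalone sketch, not from this commit: how padding=True + pad_to_multiple_of=8 behaves.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # any tokenizer with a pad token

collator = DataCollatorForSeq2Seq(
    tokenizer=tok,
    padding=True,           # pad to the longest sequence in the batch
    pad_to_multiple_of=8,   # then round the padded length up to a multiple of 8
    return_tensors="pt",
)

features = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [4, 5, 6, 7, 8], "labels": [4, 5, 6, 7, 8]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 8]): longest is 5, rounded up to 8
print(batch["labels"][0])        # positions past the real labels are filled with -100
```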
@@ -0,0 +1,23 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export BIGDL_CHECK_DUPLICATE_IMPORT=0

# You can also set the remote model repository to a local model path
python lora_finetune_chatglm.py \
       yahma/alpaca-cleaned \
       THUDM/chatglm3-6b \
       ./lora_config.yaml
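
The comment in the script notes that the Hub repositories can be replaced with local paths. As a hedged illustration (the local directory name below is hypothetical, not something defined by this commit), the base model could be pre-downloaded with `huggingface_hub` and the script then pointed at the local copy:

```python
# Editorial sketch, not from this commit: pre-download the base model so the
# fine-tuning script can use a local path instead of the remote repository.
from huggingface_hub import snapshot_download

local_model_dir = snapshot_download(
    repo_id="THUDM/chatglm3-6b",
    local_dir="./chatglm3-6b-local",  # hypothetical destination directory
)
print(local_model_dir)
# Then, for example:
#   python lora_finetune_chatglm.py yahma/alpaca-cleaned ./chatglm3-6b-local ./lora_config.yaml
```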
@@ -0,0 +1,29 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export BIGDL_CHECK_DUPLICATE_IMPORT=0

# You can also set the remote model repository to a local model path
mpirun -n 2 \
    python lora_finetune_chatglm.py \
        yahma/alpaca-cleaned \
        THUDM/chatglm3-6b \
        ./lora_config.yaml \
        ./deepspeed_config.json
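
Before launching the two-rank `mpirun` job, it can be worth confirming that both Arc cards are visible to PyTorch. The check below is an editorial sketch that assumes `intel_extension_for_pytorch` is installed in the environment (it is not part of this commit):

```python
# Editorial sketch, not from this commit: confirm two XPU devices are visible
# before running the 2-card mpirun launch above.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the torch.xpu backend)

assert torch.xpu.is_available(), "no XPU device detected"
num_xpus = torch.xpu.device_count()
print(f"Visible XPU devices: {num_xpus}")
assert num_xpus >= 2, "the data-parallel script expects 2 Arc cards"
```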