ChatGLM3-6B LoRA Fine-tuning Demo (#11450)
* ChatGLM3-6B LoRA Fine-tuning Demo
* refine
* refine
* add 2-card deepspeed
* refine format
* add mpi4py and deepspeed install

This commit is contained in:
parent e000ac90c4
commit 07362ffffc

8 changed files with 927 additions and 1 deletion
@ -0,0 +1,150 @@

# LoRA Fine-Tuning on ChatGLM3-6B with IPEX-LLM

This example ports the [ChatGLM3-6B lora_finetune](https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb) demo to IPEX-LLM on [Intel Arc GPU](../../README.md).

### 1. Install

```bash
conda create -n llm python=3.11
conda activate llm
pip install "jieba>=0.42.1"
pip install "ruamel_yaml>=0.18.6"
pip install "rouge_chinese>=1.0.3"
pip install "jupyter>=1.0.0"
pip install "datasets>=2.18.0"
pip install "peft>=0.10.0"
pip install typer
pip install sentencepiece
pip install nltk
pip install "numpy<2.0.0"
pip install "deepspeed==0.13.1"
pip install "mpi4py>=3.1.5"
# the below command installs intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```

### 2. Configure OneAPI Environment Variables

```bash
source /opt/intel/oneapi/setvars.sh
```
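
Optionally, verify that the Arc GPU is exposed as an XPU device before starting fine-tuning. The snippet below is a minimal check, assuming the XPU build of PyTorch and `intel_extension_for_pytorch` installed above:

```python
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401, registers the 'xpu' device with PyTorch

# Expect True and the number of Arc cards you plan to use (1 or 2).
print(torch.xpu.is_available())
print(torch.xpu.device_count())
```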

### 3. LoRA Fine-Tune on ChatGLM3-6B

First, download the dataset: we use `AdvertiseGen` to fine-tune ChatGLM3-6B in the following steps. Get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), unzip it in the current directory, and then process it with the script below:

```bash
python process_advertise_gen_dataset.py
```

This converts `./AdvertiseGen` into `./AdvertiseGen_fix`. With the dataset prepared, we can now start LoRA fine-tuning on ChatGLM3-6B.
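
For reference, `process_advertise_gen_dataset.py` (included at the end of this change) rewrites each raw `AdvertiseGen` line into the conversation format consumed by the fine-tuning script. A minimal sketch with placeholder text (not real dataset entries):

```python
import json

# Each raw AdvertiseGen line carries 'content' (ad tags) and 'summary' (ad copy).
raw = {'content': '类型#裙*颜色#蓝色', 'summary': '这是一条蓝色的连衣裙……'}

# The conversion script turns it into the chat format used for training:
fixed = {'conversations': [{'role': 'user', 'content': raw['content']},
                           {'role': 'assistant', 'content': raw['summary']}]}
print(json.dumps(fixed, ensure_ascii=False))
```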

#### 3.1. Fine-Tune with a Single Arc Card

Start the fine-tuning by:

```bash
bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
```

Then, you will see output like the following:

```bash
2024-06-27 13:47:02,680 - root - INFO - intel_extension_for_pytorch auto imported
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  6.47it/s]
2024-06-27 13:47:03,794 - ipex_llm.transformers.utils - INFO - Converting the current model to bf16 format......
[2024-06-27 13:47:04,105] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to xpu (auto detect)
trainable params: 487,424 || all params: 6,244,071,424 || trainable%: 0.0078
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ChatGLMForConditionalGeneration(
      (transformer): ChatGLMModel(
        (embedding): Embedding(
          (word_embeddings): Embedding(65024, 4096)
        )
        (rotary_pos_emb): RotaryEmbedding()
        (encoder): GLMTransformer(
          (layers): ModuleList(
            (0-27): 28 x GLMBlock(
              (input_layernorm): RMSNorm()
              (self_attention): SelfAttention(
                (query_key_value): LoraLowBitLinear(
                  (base_layer): BF16Linear(in_features=4096, out_features=4608, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=2, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=2, out_features=4608, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                  (qa_pool): Identity()
                )
                (core_attention): CoreAttention(
                  (attention_dropout): Dropout(p=0.0, inplace=False)
                )
                (dense): BF16Linear(in_features=4096, out_features=4096, bias=False)
              )
              (post_attention_layernorm): RMSNorm()
              (mlp): MLP(
                (dense_h_to_4h): BF16Linear(in_features=4096, out_features=27392, bias=False)
                (dense_4h_to_h): BF16Linear(in_features=13696, out_features=4096, bias=False)
              )
            )
          )
          (final_layernorm): RMSNorm()
        )
        (output_layer): BF16Linear(in_features=4096, out_features=65024, bias=False)
      )
    )
  )
)
--> Model

--> model has 0.487424M params

train_dataset: Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 114599
})
val_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 1070
})
test_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 1070
})
--> Sanity check
           '[gMASK]': 64790 -> -100
               'sop': 64792 -> -100
          '<|user|>': 64795 -> -100
                  '': 30910 -> -100
                '\n': 13 -> -100
......

# Here it takes time to finish the whole fine-tuning

......

Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': xxxx.xxxx, 'train_samples_per_second': x.xxx, 'train_steps_per_second': x.xxx, 'train_loss': xx.xx, 'epoch': x.xx}
100%|████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [xx:xx<00:00, x.xxit/s]
***** Running Prediction *****
  Num examples = 1070
  Batch size = 4
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [xx:xx<00:00, x.xxs/it]
```

#### 3.2. Fine-Tune with 2 Arc Cards

Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:

```bash
bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
```

@ -0,0 +1,15 @@

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu"
    },
    "contiguous_gradients": true,
    "overlap_comm": true
  },
  "bf16": {
    "enabled": true
  },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto"
}
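
The DeepSpeed JSON above enables ZeRO stage 2 with the optimizer state offloaded to CPU and BF16 training, while micro-batch size and gradient accumulation are inherited from the trainer ("auto"). The fine-tuning script passes this file to `transformers` roughly as sketched below; the helper name `apply_deepspeed` is introduced here only for illustration:

```python
from transformers import Seq2SeqTrainingArguments


def apply_deepspeed(training_args: Seq2SeqTrainingArguments,
                    deepspeed_config_file: str) -> Seq2SeqTrainingArguments:
    # When a DeepSpeed config path is given, run data-parallel over Intel oneCCL
    # and let DeepSpeed (ZeRO-2 + CPU optimizer offload) manage optimizer state.
    if deepspeed_config_file != '':
        training_args.ddp_backend = "ccl"
        training_args.deepspeed = deepspeed_config_file
    return training_args
```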

@ -0,0 +1,47 @@

# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/configs/lora.yaml
data_config:
  train_file: train.json
  val_file: dev.json
  test_file: dev.json
  num_proc: 16
max_input_length: 128
max_output_length: 128
training_args:
  # see `transformers.Seq2SeqTrainingArguments`
  output_dir: ./output
  max_steps: 3000
  # needs to be adjusted to fit the dataset
  learning_rate: 5e-5
  # settings for data loading
  per_device_train_batch_size: 1
  dataloader_num_workers: 16
  remove_unused_columns: false
  # settings for saving checkpoints
  save_strategy: steps
  save_steps: 500
  # settings for logging
  log_level: info
  logging_strategy: steps
  logging_steps: 10
  # settings for evaluation
  per_device_eval_batch_size: 4
  evaluation_strategy: steps
  eval_steps: 1000
  # settings for optimizer
  # adam_epsilon: 1e-6
  # uncomment the following line to detect nan or inf values
  # debug: underflow_overflow
  predict_with_generate: true
  # see `transformers.GenerationConfig`
  generation_config:
    max_new_tokens: 128
  # set your absolute deepspeed path here
  #deepspeed: ds_zero_2.json
  # set to true to train on CPU
  use_cpu: false
peft_config:
  peft_type: LORA
  task_type: CAUSAL_LM
  r: 2
  lora_alpha: 8
  lora_dropout: 0.1
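
The `peft_config` block is converted into a PEFT config object by the training script via `get_peft_config`; it corresponds roughly to building the following `LoraConfig` with the public `peft` API (a sketch, not code from this change):

```python
from peft import LoraConfig, TaskType

# Rank-2 LoRA adapter for causal language modeling, matching the YAML above.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=2,
    lora_alpha=8,
    lora_dropout=0.1,
)
```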

@ -0,0 +1,603 @@

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The below 2 lines differ from the original example: transformers is patched with IPEX-LLM
from ipex_llm import llm_patch
llm_patch(train=True)

# The below example is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/finetune_hf.py
# L417, L474 and L544-L546 are modified to enable the example on Intel Arc
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from peft import (
    PeftConfig,
    PeftModelForCausalLM,
    get_peft_config,
    get_peft_model
)
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    Seq2SeqTrainingArguments, AutoConfig,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq

from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = (
            [feature['output_ids'] for feature in features]
            if 'output_ids' in features[0].keys()
            else None
        )
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                    max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):
                    feature['output_ids'] = feature['output_ids'] + remainder
                else:
                    feature['output_ids'] = np.concatenate(
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)


class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.args.predict_with_generate:
            output_ids = inputs.pop('output_ids')
        input_ids = inputs['input_ids']
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
        )
        generated_tokens = generated_tokens[:, input_ids.size()[1]:]
        if self.args.predict_with_generate:
            labels = output_ids
        return loss, generated_tokens, labels


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def _sanity_check(
        input_ids: Sequence[int],
        output_ids: Sequence[int],
        tokenizer: PreTrainedTokenizer,
):
    print('--> Sanity check')
    for in_id, out_id in zip(input_ids, output_ids):
        if in_id == 0:
            continue
        if in_id in tokenizer.tokenizer.index_special_tokens:
            in_text = tokenizer.tokenizer.index_special_tokens[in_id]
        else:
            in_text = tokenizer.decode([in_id])
        print(f'{repr(in_text):>20}: {in_id} -> {out_id}')


@functools.cache
def _get_yaml_parser() -> yaml.YAML:
    parser = yaml.YAML(typ='safe', pure=True)
    parser.indent(mapping=2, offset=2, sequence=4)
    parser.default_flow_style = False
    return parser


@dc.dataclass
class DataConfig(object):
    train_file: str
    val_file: Optional[str] = None
    test_file: Optional[str] = None

    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            # skips the evaluation stage when `do_eval` or `eval_file` is not provided
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = 'no'
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        training_args = kwargs.get('training_args', None)
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get('generation_config')
            # TODO: a bit hacky
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get('data_config')
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        peft_config = kwargs.get('peft_config', None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        path = _resolve_path(path)
        kwargs = _get_yaml_parser().load(path)
        return cls.from_dict(**kwargs)


def _load_datasets(
        data_dir: Path,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    if data_format in ('.csv', '.json', '.jsonl'):
        dataset_dct = load_dataset(
            data_format[1:],
            data_dir=data_dir,
            data_files=data_files,
            num_proc=num_proc,
        )
    else:
        err_msg = f"Cannot load dataset in the '{data_format}' format."
        raise NotImplementedError(err_msg)

    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            _resolve_path(data_dir),
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
            self,
            split: NamedSplit,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


def print_model_size(model: PreTrainedModel):
    print("--> Model")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> model has {total_params / 1e6}M params\n")


def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_tools = batch.get('tools', None)
    batched_conv = batch['conversations']
    batched_input_ids = []
    batched_labels = []

    if batched_tools is None:
        batched_tools = [None] * len(batched_conv)

    for tools, conv in zip(batched_tools, batched_conv):
        input_ids, loss_masks = [
            tokenizer.get_command('[gMASK]'),
            tokenizer.get_command('sop'),
        ], [False, False]

        if tools is not None:
            raise NotImplementedError()

        for message in conv:
            if message['role'] in ('system', 'user'):
                loss_mask_val = False
            else:
                loss_mask_val = True

            if message['role'] == 'tool':
                raise NotImplementedError()
            else:
                new_input_ids = tokenizer.build_single_message(
                    message['role'], '', message['content']
                )
                new_loss_masks = [loss_mask_val] * len(new_input_ids)

            input_ids += new_input_ids
            loss_masks += new_loss_masks

        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}


def process_batch_eval(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_tools = batch.get('tools', None)
    batched_conv = batch['conversations']
    batched_input_ids = []
    # To avoid computing loss, we do not provide the `labels` field in the input dictionary.
    batched_output_ids = []

    if batched_tools is None:
        batched_tools = [None] * len(batched_conv)

    for tools, conv in zip(batched_tools, batched_conv):
        input_ids = [
            tokenizer.get_command('[gMASK]'),
            tokenizer.get_command('sop'),
        ]

        if tools is not None:
            raise NotImplementedError()

        for message in conv:
            if len(input_ids) >= max_input_length:
                break
            if message['role'] == 'tool':
                raise NotImplementedError()
            else:
                new_input_ids = tokenizer.build_single_message(
                    message['role'], '', message['content']
                )
                if message['role'] == 'assistant':
                    output_prompt, output_ids = (
                        new_input_ids[:1],
                        new_input_ids[1:],
                    )
                    output_ids.append(tokenizer.eos_token_id)
                    batched_input_ids.append(
                        input_ids[:max_input_length] + output_prompt[:1]
                    )
                    batched_output_ids.append(output_ids[:max_output_length])
                input_ids += new_input_ids
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}


# Not sure if this is necessary, can set it to half.
# If train with cpu, cast all params to fp32 instead of trainable ones.
def _prepare_model_for_training(model: nn.Module, use_cpu: bool):
    for param in model.parameters():
        if param.requires_grad or use_cpu:
            param.data = param.data.to(torch.float32)


def load_tokenizer_and_model(
        model_dir: str,
        peft_config: Optional[PeftConfig] = None,
) -> tuple[PreTrainedTokenizer, nn.Module]:
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if peft_config is not None:
        if peft_config.peft_type.name == "PREFIX_TUNING":
            config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
            config.pre_seq_len = peft_config.num_virtual_tokens
            config.use_cache = False
            model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                trust_remote_code=True,
                config=config,
            )
        if peft_config.peft_type.name == "LORA":
            # Add below L417 to enable accelerator to schedule model to Intel Arc XPU
            os.environ["ACCELERATE_USE_XPU"] = "true"
            model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                trust_remote_code=True,
                empty_init=False,
                use_cache=False,
                torch_dtype=torch.bfloat16
            )

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        print(model)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False
        )
    print_model_size(model)
    return tokenizer, model


def compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer):
    batched_pred_ids, batched_label_ids = eval_preds

    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v['f'] * 100, 4))
        metrics_dct['bleu-4'].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=SmoothingFunction().method3,
            )
        )
    return {k: np.mean(v) for k, v in metrics_dct.items()}


@app.command()
def main(
        data_dir: Annotated[str, typer.Argument(help='')],
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        config_file: Annotated[str, typer.Argument(help='')],
        # Add below L474, which is path of deepspeed config file to enable finetuning on 2 Intel Arc XPU cards
        deepspeed_config_file: Annotated[str, typer.Argument(default='', help='if specified, will apply data parallel')],
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
        )
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print('train_dataset:', train_dataset)
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print('test_dataset:', test_dataset)

    # checks encoded dataset
    _sanity_check(
        train_dataset[0]["input_ids"], train_dataset[0]["labels"], tokenizer
    )

    # turn model to fp32
    _prepare_model_for_training(model, ft_config.training_args.use_cpu)

    ft_config.training_args.generation_config.pad_token_id = (
        tokenizer.pad_token_id
    )
    ft_config.training_args.generation_config.eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command('<|user|>'),
        tokenizer.get_command('<|observation|>'),
    ]
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    use_tokenizer = True
    if ft_config.peft_config is not None:
        use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True

    # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
    if deepspeed_config_file != '':
        ft_config.training_args.ddp_backend = "ccl"
        ft_config.training_args.deepspeed = deepspeed_config_file

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(50))),
        tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(auto_resume_from_checkpoint,
                      "The specified checkpoint sn(" +
                      auto_resume_from_checkpoint +
                      ") has not been saved. Please search for the correct checkpoint in the model output directory")

    # test stage
    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == '__main__':
    app()
@ -0,0 +1,23 @@

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export BIGDL_CHECK_DUPLICATE_IMPORT=0

# You can also set the remote model repository to a local model path
python lora_finetune_chatglm.py \
    ./AdvertiseGen_fix \
    THUDM/chatglm3-6b \
    ./lora_config.yaml
@ -0,0 +1,29 @@

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export BIGDL_CHECK_DUPLICATE_IMPORT=0

# You can also set the remote model repository to a local model path
mpirun -n 2 \
    python lora_finetune_chatglm.py \
        ./AdvertiseGen_fix \
        THUDM/chatglm3-6b \
        ./lora_config.yaml \
        ./deepspeed_config.json
@ -0,0 +1,59 @@

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb
# L60 is changed to enable users to finish all operations under one working directory

import json
from typing import Union
from pathlib import Path


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def _mkdir(dir_name: Union[str, Path]):
    dir_name = _resolve_path(dir_name)
    if not dir_name.is_dir():
        dir_name.mkdir(parents=True, exist_ok=False)


def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):
    def _convert(in_file: Path, out_file: Path):
        _mkdir(out_file.parent)
        with open(in_file, encoding='utf-8') as fin:
            with open(out_file, 'wt', encoding='utf-8') as fout:
                for line in fin:
                    dct = json.loads(line)
                    sample = {'conversations': [{'role': 'user', 'content': dct['content']},
                                                {'role': 'assistant', 'content': dct['summary']}]}
                    fout.write(json.dumps(sample, ensure_ascii=False) + '\n')

    data_dir = _resolve_path(data_dir)
    save_dir = _resolve_path(save_dir)

    train_file = data_dir / 'train.json'
    if train_file.is_file():
        out_file = save_dir / train_file.relative_to(data_dir)
        _convert(train_file, out_file)

    dev_file = data_dir / 'dev.json'
    if dev_file.is_file():
        out_file = save_dir / dev_file.relative_to(data_dir)
        _convert(dev_file, out_file)


convert_adgen('./AdvertiseGen', './AdvertiseGen_fix')
@ -17,7 +17,7 @@ This folder contains examples of running different training mode with IPEX-LLM o

 |------------|-----------------------------------------------------------------|-----------------------------------------------------------------|
 | LLaMA 2/3  | [LoRA](LoRA), [QLoRA](QLoRA), [QA-LoRA](QA-LoRA), [ReLora](ReLora) | [HF-PEFT](HF-PEFT), [axolotl](axolotl) |
 | Mistral    | [LoRA](DPO), [QLoRA](DPO) | [DPO](DPO) |
-| ChatGLM 3  | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
+| ChatGLM 3  | [LoRA](LoRA/chatglm_finetune#lora-fine-tuning-on-chatglm3-6b-with-ipex-llm), [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
 | Qwen-1.5   | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
 | Baichuan2  | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |