ChatGLM3-6B LoRA Fine-tuning Demo (#11450)

* ChatGLM3-6B LoRA Fine-tuning Demo

* refine

* refine

* add 2-card deepspeed

* refine format

* add mpi4py and deepspeed install
Heyang Sun 2024-07-01 09:18:39 +08:00 committed by GitHub
parent e000ac90c4
commit 07362ffffc
8 changed files with 927 additions and 1 deletion


@ -0,0 +1,150 @@
# LoRA Fine-Tuning on ChatGLM3-6B with IPEX-LLM
This example ports the [ChatGLM3-6B lora_finetune](https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb) demo to IPEX-LLM on [Intel Arc GPU](../../README.md).
### 1. Install
```bash
conda create -n llm python=3.11
conda activate llm
pip install "jieba>=0.42.1"
pip install "ruamel_yaml>=0.18.6"
pip install "rouge_chinese>=1.0.3"
pip install "jupyter>=1.0.0"
pip install "datasets>=2.18.0"
pip install "peft>=0.10.0"
pip install typer
pip install sentencepiece
pip install nltk
pip install "numpy<2.0.0"
pip install "deepspeed==0.13.1"
pip install "mpi4py>=3.1.5"
# the below command will install intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```
### 2. Configure OneAPI Environment Variables
```bash
source /opt/intel/oneapi/setvars.sh
```
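Optionally, you can run a quick sanity check to confirm the XPU environment is visible to PyTorch before fine-tuning. This is only a minimal sketch (not part of the original demo), assuming the install in step 1 completed and `setvars.sh` has been sourced in the same shell:
```python
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  # registers the XPU backend; installed with ipex-llm[xpu]

print("torch version:", torch.__version__)
print("XPU available:", torch.xpu.is_available())
print("XPU device count:", torch.xpu.device_count())
```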
### 3. LoRA Fine-Tune on ChatGLM3-6B
First, download the dataset: we use `AdvertiseGen` to fine-tune ChatGLM3-6B in the following example. Please download it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the script below:
```bash
python process_advertise_gen_dataset.py
```
This converts `./AdvertiseGen` into `./AdvertiseGen_fix`. With the dataset prepared, we can now start LoRA fine-tuning on ChatGLM3-6B.
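Each line of the converted `train.json`/`dev.json` is a JSON record with a `conversations` list (one `user` turn holding the product keywords and one `assistant` turn holding the reference advertisement), which is the format written by `process_advertise_gen_dataset.py` shown later in this commit. A minimal sketch to inspect the first converted record, assuming the conversion above succeeded:
```python
import json

# Print the first converted training record from ./AdvertiseGen_fix
with open("./AdvertiseGen_fix/train.json", encoding="utf-8") as f:
    sample = json.loads(f.readline())
print(json.dumps(sample, ensure_ascii=False, indent=2))
# Expected shape: {"conversations": [{"role": "user", ...}, {"role": "assistant", ...}]}
```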
#### 3.1. Fine-Tune with a Single Arc Card
Start the fine-tuning by:
```bash
bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
```
Then, you will get output as below:
```bash
2024-06-27 13:47:02,680 - root - INFO - intel_extension_for_pytorch auto imported
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.47it/s]
2024-06-27 13:47:03,794 - ipex_llm.transformers.utils - INFO - Converting the current model to bf16 format......
[2024-06-27 13:47:04,105] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to xpu (auto detect)
trainable params: 487,424 || all params: 6,244,071,424 || trainable%: 0.0078
PeftModelForCausalLM(
(base_model): LoraModel(
(model): ChatGLMForConditionalGeneration(
(transformer): ChatGLMModel(
(embedding): Embedding(
(word_embeddings): Embedding(65024, 4096)
)
(rotary_pos_emb): RotaryEmbedding()
(encoder): GLMTransformer(
(layers): ModuleList(
(0-27): 28 x GLMBlock(
(input_layernorm): RMSNorm()
(self_attention): SelfAttention(
(query_key_value): LoraLowBitLinear(
(base_layer): BF16Linear(in_features=4096, out_features=4608, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=4096, out_features=2, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=2, out_features=4608, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(qa_pool): Identity()
)
(core_attention): CoreAttention(
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(dense): BF16Linear(in_features=4096, out_features=4096, bias=False)
)
(post_attention_layernorm): RMSNorm()
(mlp): MLP(
(dense_h_to_4h): BF16Linear(in_features=4096, out_features=27392, bias=False)
(dense_4h_to_h): BF16Linear(in_features=13696, out_features=4096, bias=False)
)
)
)
(final_layernorm): RMSNorm()
)
(output_layer): BF16Linear(in_features=4096, out_features=65024, bias=False)
)
)
)
)
--> Model
--> model has 0.487424M params
train_dataset: Dataset({
features: ['input_ids', 'labels'],
num_rows: 114599
})
val_dataset: Dataset({
features: ['input_ids', 'output_ids'],
num_rows: 1070
})
test_dataset: Dataset({
features: ['input_ids', 'output_ids'],
num_rows: 1070
})
--> Sanity check
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|user|>': 64795 -> -100
'': 30910 -> -100
'\n': 13 -> -100
......
# Here it takes time to finish the whole fine-tuning
......
Training completed. Do not forget to share your model on huggingface.co/models =)
{'train_runtime': xxxx.xxxx, 'train_samples_per_second': x.xxx, 'train_steps_per_second': x.xxx, 'train_loss': xx.xx, 'epoch': x.xx}
100%|████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [xx:xx<00:00, x.xxit/s]
***** Running Prediction *****
Num examples = 1070
Batch size = 4
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [xx:xx<00:00, x.xxs/it]
```
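According to `lora_config.yaml`, checkpoints are written to `./output` every 500 steps (e.g. `./output/checkpoint-3000` at the end of the 3000-step run). The sketch below shows one possible way to load a saved LoRA adapter on top of the base model for a quick generation test; the checkpoint path and prompt are placeholders, and this loading flow is an assumption rather than part of the original demo:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model = "THUDM/chatglm3-6b"           # or a local model path
adapter_path = "./output/checkpoint-3000"  # hypothetical: pick the checkpoint you want to test

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# Attach the LoRA weights produced by the fine-tuning run above
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()

inputs = tokenizer("类型#裙*风格#简约", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```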
#### 3.2. Fine-Tune with 2 Arc Cards
Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:
```bash
bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
```
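The extra fourth argument (`./deepspeed_config.json`, shown later in this commit) is what enables data parallelism: when it is supplied, the fine-tuning script switches the distributed backend to CCL and hands the config to DeepSpeed, roughly as in the excerpt below from `lora_finetune_chatglm.py` in this commit:
```python
# Excerpt from lora_finetune_chatglm.py: wiring up the optional DeepSpeed config
if deepspeed_config_file != '':
    ft_config.training_args.ddp_backend = "ccl"
    ft_config.training_args.deepspeed = deepspeed_config_file
```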


@ -0,0 +1,15 @@
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": true
},
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto"
}


@ -0,0 +1,47 @@
# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/configs/lora.yaml
data_config:
train_file: train.json
val_file: dev.json
test_file: dev.json
num_proc: 16
max_input_length: 128
max_output_length: 128
training_args:
# see `transformers.Seq2SeqTrainingArguments`
output_dir: ./output
max_steps: 3000
  # needs to be tuned to fit the dataset
learning_rate: 5e-5
# settings for data loading
per_device_train_batch_size: 1
dataloader_num_workers: 16
remove_unused_columns: false
# settings for saving checkpoints
save_strategy: steps
save_steps: 500
# settings for logging
log_level: info
logging_strategy: steps
logging_steps: 10
# settings for evaluation
per_device_eval_batch_size: 4
evaluation_strategy: steps
eval_steps: 1000
# settings for optimizer
# adam_epsilon: 1e-6
# uncomment the following line to detect nan or inf values
# debug: underflow_overflow
predict_with_generate: true
# see `transformers.GenerationConfig`
generation_config:
max_new_tokens: 128
# set your absolute deepspeed path here
#deepspeed: ds_zero_2.json
  # set to true to train with cpu.
use_cpu: false
peft_config:
peft_type: LORA
task_type: CAUSAL_LM
r: 2
lora_alpha: 8
lora_dropout: 0.1
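For reference (not part of the original files), the training script in this commit parses the `peft_config` section above into a peft `LoraConfig` via `get_peft_config`, and the `training_args` section into `transformers.Seq2SeqTrainingArguments`. A minimal sketch of that mapping, with values copied from the YAML above:
```python
from peft import get_peft_config
from transformers import Seq2SeqTrainingArguments

# Equivalent of the peft_config section above
lora_config = get_peft_config({
    "peft_type": "LORA",
    "task_type": "CAUSAL_LM",
    "r": 2,
    "lora_alpha": 8,
    "lora_dropout": 0.1,
})
print(type(lora_config).__name__)  # LoraConfig

# A few of the training_args entries above
training_args = Seq2SeqTrainingArguments(
    output_dir="./output",
    max_steps=3000,
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    predict_with_generate=True,
)
```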


@ -0,0 +1,603 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The below 2 lines differ from the original example: here, transformers is patched with IPEX-LLM
from ipex_llm import llm_patch
llm_patch(train=True)
# The example below is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/finetune_hf.py
# L417, L474 and L544-L546 are modified to enable the example on Intel Arc
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from peft import (
PeftConfig,
PeftModelForCausalLM,
    get_peft_config,
    get_peft_model
)
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
EvalPrediction,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
Seq2SeqTrainingArguments, AutoConfig,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
output_ids = (
[feature['output_ids'] for feature in features]
if 'output_ids' in features[0].keys()
else None
)
if output_ids is not None:
max_output_length = max(len(out) for out in output_ids)
if self.pad_to_multiple_of is not None:
max_output_length = (
(
max_output_length + self.pad_to_multiple_of - 1) //
self.pad_to_multiple_of * self.pad_to_multiple_of
)
for feature in features:
remainder = [self.tokenizer.pad_token_id] * (
max_output_length - len(feature['output_ids'])
)
if isinstance(feature['output_ids'], list):
feature['output_ids'] = feature['output_ids'] + remainder
else:
feature['output_ids'] = np.concatenate(
[feature['output_ids'], remainder]
).astype(np.int64)
return super().__call__(features, return_tensors)
class Seq2SeqTrainer(_Seq2SeqTrainer):
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, Any],
prediction_loss_only: bool,
ignore_keys=None,
**gen_kwargs,
) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.args.predict_with_generate:
output_ids = inputs.pop('output_ids')
input_ids = inputs['input_ids']
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
)
generated_tokens = generated_tokens[:, input_ids.size()[1]:]
if self.args.predict_with_generate:
labels = output_ids
return loss, generated_tokens, labels
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def _sanity_check(
input_ids: Sequence[int],
output_ids: Sequence[int],
tokenizer: PreTrainedTokenizer,
):
print('--> Sanity check')
for in_id, out_id in zip(input_ids, output_ids):
if in_id == 0:
continue
if in_id in tokenizer.tokenizer.index_special_tokens:
in_text = tokenizer.tokenizer.index_special_tokens[in_id]
else:
in_text = tokenizer.decode([in_id])
print(f'{repr(in_text):>20}: {in_id} -> {out_id}')
@functools.cache
def _get_yaml_parser() -> yaml.YAML:
parser = yaml.YAML(typ='safe', pure=True)
parser.indent(mapping=2, offset=2, sequence=4)
parser.default_flow_style = False
return parser
@dc.dataclass
class DataConfig(object):
train_file: str
val_file: Optional[str] = None
test_file: Optional[str] = None
num_proc: Optional[int] = None
@property
def data_format(self) -> str:
return Path(self.train_file).suffix
@property
def data_files(self) -> dict[NamedSplit, str]:
return {
split: data_file
for split, data_file in zip(
[Split.TRAIN, Split.VALIDATION, Split.TEST],
[self.train_file, self.val_file, self.test_file],
)
if data_file is not None
}
@dc.dataclass
class FinetuningConfig(object):
data_config: DataConfig
max_input_length: int
max_output_length: int
training_args: Seq2SeqTrainingArguments = dc.field(
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
)
peft_config: Optional[PeftConfig] = None
def __post_init__(self):
if not self.training_args.do_eval or self.data_config.val_file is None:
# skips the evaluation stage when `do_eval` or `eval_file` is not provided
self.training_args.do_eval = False
self.training_args.evaluation_strategy = 'no'
self.data_config.val_file = None
else:
self.training_args.per_device_eval_batch_size = (
self.training_args.per_device_eval_batch_size
or self.training_args.per_device_train_batch_size
)
@classmethod
def from_dict(cls, **kwargs) -> 'FinetuningConfig':
training_args = kwargs.get('training_args', None)
if training_args is not None and not isinstance(
training_args, Seq2SeqTrainingArguments
):
gen_config = training_args.get('generation_config')
# TODO: a bit hacky
if not isinstance(gen_config, GenerationConfig):
training_args['generation_config'] = GenerationConfig(
**gen_config
)
kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)
data_config = kwargs.get('data_config')
if not isinstance(data_config, DataConfig):
kwargs['data_config'] = DataConfig(**data_config)
peft_config = kwargs.get('peft_config', None)
if peft_config is not None and not isinstance(peft_config, PeftConfig):
kwargs['peft_config'] = get_peft_config(peft_config)
return cls(**kwargs)
@classmethod
def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
path = _resolve_path(path)
kwargs = _get_yaml_parser().load(path)
return cls.from_dict(**kwargs)
def _load_datasets(
data_dir: Path,
data_format: str,
data_files: dict[NamedSplit, str],
num_proc: Optional[int],
) -> DatasetDict:
if data_format in ('.csv', '.json', '.jsonl'):
dataset_dct = load_dataset(
data_format[1:],
data_dir=data_dir,
data_files=data_files,
num_proc=num_proc,
)
else:
err_msg = f"Cannot load dataset in the '{data_format}' format."
raise NotImplementedError(err_msg)
return dataset_dct
class DataManager(object):
def __init__(self, data_dir: str, data_config: DataConfig):
self._num_proc = data_config.num_proc
self._dataset_dct = _load_datasets(
_resolve_path(data_dir),
data_config.data_format,
data_config.data_files,
self._num_proc,
)
def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
return self._dataset_dct.get(split, None)
def get_dataset(
self,
split: NamedSplit,
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
batched: bool = True,
remove_orig_columns: bool = True,
) -> Optional[Dataset]:
orig_dataset = self._get_dataset(split)
if orig_dataset is None:
return
if remove_orig_columns:
remove_columns = orig_dataset.column_names
else:
remove_columns = None
return orig_dataset.map(
process_fn,
batched=batched,
remove_columns=remove_columns,
num_proc=self._num_proc,
)
def print_model_size(model: PreTrainedModel):
print("--> Model")
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n--> model has {total_params / 1e6}M params\n")
def process_batch(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_tools = batch.get('tools', None)
batched_conv = batch['conversations']
batched_input_ids = []
batched_labels = []
if batched_tools is None:
batched_tools = [None] * len(batched_conv)
for tools, conv in zip(batched_tools, batched_conv):
input_ids, loss_masks = [
tokenizer.get_command('[gMASK]'),
tokenizer.get_command('sop'),
], [False, False]
if tools is not None:
raise NotImplementedError()
for message in conv:
if message['role'] in ('system', 'user'):
loss_mask_val = False
else:
loss_mask_val = True
if message['role'] == 'tool':
raise NotImplementedError()
else:
new_input_ids = tokenizer.build_single_message(
message['role'], '', message['content']
)
new_loss_masks = [loss_mask_val] * len(new_input_ids)
input_ids += new_input_ids
loss_masks += new_loss_masks
input_ids.append(tokenizer.eos_token_id)
loss_masks = [False, *loss_masks]
labels = []
for input_id, mask in zip(input_ids, loss_masks):
if mask:
labels.append(input_id)
else:
labels.append(-100)
max_length = max_input_length + max_output_length + 1
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
return {'input_ids': batched_input_ids, 'labels': batched_labels}
def process_batch_eval(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_tools = batch.get('tools', None)
batched_conv = batch['conversations']
batched_input_ids = []
# To avoid computing loss, we do not provide the `labels` field in the input dictionary.
batched_output_ids = []
if batched_tools is None:
batched_tools = [None] * len(batched_conv)
for tools, conv in zip(batched_tools, batched_conv):
input_ids = [
tokenizer.get_command('[gMASK]'),
tokenizer.get_command('sop'),
]
if tools is not None:
raise NotImplementedError()
for message in conv:
if len(input_ids) >= max_input_length:
break
if message['role'] == 'tool':
raise NotImplementedError()
else:
new_input_ids = tokenizer.build_single_message(
message['role'], '', message['content']
)
if message['role'] == 'assistant':
output_prompt, output_ids = (
new_input_ids[:1],
new_input_ids[1:],
)
output_ids.append(tokenizer.eos_token_id)
batched_input_ids.append(
input_ids[:max_input_length] + output_prompt[:1]
)
batched_output_ids.append(output_ids[:max_output_length])
input_ids += new_input_ids
return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}
# Not sure if this is necessary; it may also be fine to keep the params in half precision.
# If training with cpu, cast all params to fp32 instead of only the trainable ones.
def _prepare_model_for_training(model: nn.Module, use_cpu: bool):
for param in model.parameters():
if param.requires_grad or use_cpu:
param.data = param.data.to(torch.float32)
def load_tokenizer_and_model(
model_dir: str,
peft_config: Optional[PeftConfig] = None,
) -> tuple[PreTrainedTokenizer, nn.Module]:
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
if peft_config is not None:
if peft_config.peft_type.name == "PREFIX_TUNING":
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
config.pre_seq_len = peft_config.num_virtual_tokens
config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
config=config,
)
if peft_config.peft_type.name == "LORA":
# Add below L417 to enable accelerator to schedule model to Intel Arc XPU
os.environ["ACCELERATE_USE_XPU"] = "true"
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False,
torch_dtype=torch.bfloat16
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print(model)
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
empty_init=False,
use_cache=False
)
print_model_size(model)
return tokenizer, model
def compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer):
batched_pred_ids, batched_label_ids = eval_preds
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
pred_txt = tokenizer.decode(pred_ids).strip()
label_txt = tokenizer.decode(label_ids).strip()
pred_tokens = list(jieba.cut(pred_txt))
label_tokens = list(jieba.cut(label_txt))
rouge = Rouge()
scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
for k, v in scores[0].items():
metrics_dct[k].append(round(v['f'] * 100, 4))
metrics_dct['bleu-4'].append(
sentence_bleu(
[label_tokens],
pred_tokens,
smoothing_function=SmoothingFunction().method3,
)
)
return {k: np.mean(v) for k, v in metrics_dct.items()}
@app.command()
def main(
data_dir: Annotated[str, typer.Argument(help='')],
model_dir: Annotated[
str,
typer.Argument(
help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
),
],
config_file: Annotated[str, typer.Argument(help='')],
        # Add below L474, the path of a deepspeed config file, to enable fine-tuning on 2 Intel Arc XPU cards
        deepspeed_config_file: Annotated[str, typer.Argument(default='', help='if specified, will apply data parallel')],
auto_resume_from_checkpoint: str = typer.Argument(
default='',
            help='If entered as yes, automatically use the latest saved checkpoint. If a number such as 12 or 15 is entered, use the corresponding saved checkpoint. If entered as no, restart training'
)
):
ft_config = FinetuningConfig.from_file(config_file)
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
train_dataset = data_manager.get_dataset(
Split.TRAIN,
functools.partial(
process_batch,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
print('train_dataset:', train_dataset)
val_dataset = data_manager.get_dataset(
Split.VALIDATION,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if val_dataset is not None:
print('val_dataset:', val_dataset)
test_dataset = data_manager.get_dataset(
Split.TEST,
functools.partial(
process_batch_eval,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
),
batched=True,
)
if test_dataset is not None:
print('test_dataset:', test_dataset)
# checks encoded dataset
_sanity_check(
train_dataset[0]["input_ids"], train_dataset[0]["labels"], tokenizer
)
# turn model to fp32
_prepare_model_for_training(model, ft_config.training_args.use_cpu)
ft_config.training_args.generation_config.pad_token_id = (
tokenizer.pad_token_id
)
ft_config.training_args.generation_config.eos_token_id = [
tokenizer.eos_token_id,
tokenizer.get_command('<|user|>'),
tokenizer.get_command('<|observation|>'),
]
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
use_tokenizer = True
if ft_config.peft_config is not None:
use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
# Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
    if deepspeed_config_file != '':
ft_config.training_args.ddp_backend = "ccl"
ft_config.training_args.deepspeed = deepspeed_config_file
trainer = Seq2SeqTrainer(
model=model,
args=ft_config.training_args,
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding='longest',
return_tensors='pt',
),
train_dataset=train_dataset,
eval_dataset=val_dataset.select(list(range(50))),
tokenizer=tokenizer if use_tokenizer else None, # LORA does not need tokenizer
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
)
    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
trainer.train()
else:
output_dir = ft_config.training_args.output_dir
dirlist = os.listdir(output_dir)
checkpoint_sn = 0
for checkpoint_str in dirlist:
if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
if checkpoint > checkpoint_sn:
checkpoint_sn = checkpoint
if auto_resume_from_checkpoint.upper() == "YES":
if checkpoint_sn > 0:
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
trainer.train()
else:
if auto_resume_from_checkpoint.isdigit():
if int(auto_resume_from_checkpoint) > 0:
checkpoint_sn = int(auto_resume_from_checkpoint)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
trainer.train(resume_from_checkpoint=checkpoint_directory)
else:
print(auto_resume_from_checkpoint,
"The specified checkpoint sn(" +
auto_resume_from_checkpoint +
") has not been saved. Please search for the correct chkeckpoint in the model output directory")
# test stage
if test_dataset is not None:
trainer.predict(test_dataset)
if __name__ == '__main__':
app()


@ -0,0 +1,23 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export BIGDL_CHECK_DUPLICATE_IMPORT=0
# You can also replace the remote model repository below with a local model path
python lora_finetune_chatglm.py \
./AdvertiseGen_fix \
THUDM/chatglm3-6b \
./lora_config.yaml


@ -0,0 +1,29 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export BIGDL_CHECK_DUPLICATE_IMPORT=0
# You can also replace the remote model repository below with a local model path
mpirun -n 2 \
python lora_finetune_chatglm.py \
./AdvertiseGen_fix \
THUDM/chatglm3-6b \
./lora_config.yaml \
./deepspeed_config.json


@ -0,0 +1,59 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb
# L60 is changed to enable users to finish all operations under one working directory
import json
from typing import Union
from pathlib import Path
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def _mkdir(dir_name: Union[str, Path]):
dir_name = _resolve_path(dir_name)
if not dir_name.is_dir():
dir_name.mkdir(parents=True, exist_ok=False)
def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):
def _convert(in_file: Path, out_file: Path):
_mkdir(out_file.parent)
with open(in_file, encoding='utf-8') as fin:
with open(out_file, 'wt', encoding='utf-8') as fout:
for line in fin:
dct = json.loads(line)
sample = {'conversations': [{'role': 'user', 'content': dct['content']},
{'role': 'assistant', 'content': dct['summary']}]}
fout.write(json.dumps(sample, ensure_ascii=False) + '\n')
data_dir = _resolve_path(data_dir)
save_dir = _resolve_path(save_dir)
train_file = data_dir / 'train.json'
if train_file.is_file():
out_file = save_dir / train_file.relative_to(data_dir)
_convert(train_file, out_file)
dev_file = data_dir / 'dev.json'
if dev_file.is_file():
out_file = save_dir / dev_file.relative_to(data_dir)
_convert(dev_file, out_file)
convert_adgen('./AdvertiseGen', './AdvertiseGen_fix')


@ -17,7 +17,7 @@ This folder contains examples of running different training mode with IPEX-LLM o
|------------|-----------------------------------------------------------------|-----------------------------------------------------------------|
| LLaMA 2/3 | [LoRA](LoRA), [QLoRA](QLoRA), [QA-LoRA](QA-LoRA), [ReLora](ReLora) | [HF-PEFT](HF-PEFT), [axolotl](axolotl) |
| Mistral | [LoRA](DPO), [QLoRA](DPO) | [DPO](DPO) |
| ChatGLM 3 | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
| ChatGLM 3 | [LoRA](LoRA/chatglm_finetune#lora-fine-tuning-on-chatglm3-6b-with-ipex-llm), [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
| Qwen-1.5 | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
| Baichuan2 | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |