ChatGLM3-6B LoRA Fine-tuning Demo (#11450)
* ChatGLM3-6B LoRA Fine-tuning Demo
* refine
* refine
* add 2-card deepspeed
* refine format
* add mpi4py and deepspeed install

parent e000ac90c4
commit 07362ffffc

8 changed files with 927 additions and 1 deletion

@@ -0,0 +1,150 @@
# LoRA Fine-Tuning on ChatGLM3-6B with IPEX-LLM

This example ports the [ChatGLM3-6B lora_finetune](https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb) demo to IPEX-LLM on [Intel Arc GPU](../../README.md).

### 1. Install

```bash
conda create -n llm python=3.11
conda activate llm
pip install "jieba>=0.42.1"
pip install "ruamel_yaml>=0.18.6"
pip install "rouge_chinese>=1.0.3"
pip install "jupyter>=1.0.0"
pip install "datasets>=2.18.0"
pip install "peft>=0.10.0"
pip install typer
pip install sentencepiece
pip install nltk
pip install "numpy<2.0.0"
pip install "deepspeed==0.13.1"
pip install "mpi4py>=3.1.5"
# the below command will install intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```

### 2. Configure OneAPI Environment Variables

```bash
source /opt/intel/oneapi/setvars.sh
```
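
Optionally, you can verify that the Arc GPU is visible to the XPU runtime before fine-tuning. This quick check is not part of the original demo, and it assumes the oneAPI environment above has been sourced:

```bash
# `sycl-ls` ships with the oneAPI toolkit and should list a `level_zero:gpu` entry for the Arc card
sycl-ls
# IPEX registers the XPU backend in PyTorch; this is expected to print `True`
python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.xpu.is_available())"
```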

### 3. LoRA Fine-Tune on ChatGLM3-6B

First, download the dataset: this example uses `AdvertiseGen` to fine-tune ChatGLM3-6B. Get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) and unzip it in the current directory. Then, process the dataset with the script below:

```bash
python process_advertise_gen_dataset.py
```

This converts `./AdvertiseGen` into `./AdvertiseGen_fix`. With the dataset prepared, we can now start LoRA fine-tuning on ChatGLM3-6B.
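
For reference, `process_advertise_gen_dataset.py` (included later in this change) rewrites each `content`/`summary` pair of the raw dataset into a single-turn conversation record. A quick way to inspect the converted files, assuming the conversion above succeeded, is:

```bash
# each line of ./AdvertiseGen_fix/train.json should look like
#   {"conversations": [{"role": "user", "content": "<original content field>"}, {"role": "assistant", "content": "<original summary field>"}]}
head -n 1 ./AdvertiseGen_fix/train.json
```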

#### 3.1. Fine-Tune with a Single Arc Card

Start the fine-tuning by:

```bash
bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
```
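
Under the hood, `lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh` (included later in this change) simply launches the ported fine-tuning script with the processed dataset, the model repository id, and the LoRA config:

```bash
python lora_finetune_chatglm.py \
    ./AdvertiseGen_fix \
    THUDM/chatglm3-6b \
    ./lora_config.yaml
```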

Then, you will get output as below:

```bash
2024-06-27 13:47:02,680 - root - INFO - intel_extension_for_pytorch auto imported
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.47it/s]
2024-06-27 13:47:03,794 - ipex_llm.transformers.utils - INFO - Converting the current model to bf16 format......
[2024-06-27 13:47:04,105] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to xpu (auto detect)
trainable params: 487,424 || all params: 6,244,071,424 || trainable%: 0.0078
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ChatGLMForConditionalGeneration(
      (transformer): ChatGLMModel(
        (embedding): Embedding(
          (word_embeddings): Embedding(65024, 4096)
        )
        (rotary_pos_emb): RotaryEmbedding()
        (encoder): GLMTransformer(
          (layers): ModuleList(
            (0-27): 28 x GLMBlock(
              (input_layernorm): RMSNorm()
              (self_attention): SelfAttention(
                (query_key_value): LoraLowBitLinear(
                  (base_layer): BF16Linear(in_features=4096, out_features=4608, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=2, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=2, out_features=4608, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                  (qa_pool): Identity()
                )
                (core_attention): CoreAttention(
                  (attention_dropout): Dropout(p=0.0, inplace=False)
                )
                (dense): BF16Linear(in_features=4096, out_features=4096, bias=False)
              )
              (post_attention_layernorm): RMSNorm()
              (mlp): MLP(
                (dense_h_to_4h): BF16Linear(in_features=4096, out_features=27392, bias=False)
                (dense_4h_to_h): BF16Linear(in_features=13696, out_features=4096, bias=False)
              )
            )
          )
          (final_layernorm): RMSNorm()
        )
        (output_layer): BF16Linear(in_features=4096, out_features=65024, bias=False)
      )
    )
  )
)
--> Model

--> model has 0.487424M params

train_dataset: Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 114599
})
val_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 1070
})
test_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 1070
})
--> Sanity check
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|user|>': 64795 -> -100
'': 30910 -> -100
'\n': 13 -> -100
......

# Here it takes time to finish the whole fine-tuning

......

Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': xxxx.xxxx, 'train_samples_per_second': x.xxx, 'train_steps_per_second': x.xxx, 'train_loss': xx.xx, 'epoch': x.xx}
100%|████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [xx:xx<00:00, x.xxit/s]
***** Running Prediction *****
Num examples = 1070
Batch size = 4
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [xx:xx<00:00, x.xxs/it]
```

#### 3.2. Fine-Tune with 2 Arc Cards

Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:

```bash
bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
```
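
This script (also included later in this change) launches two ranks with `mpirun` and additionally passes the DeepSpeed ZeRO-2 config, which enables CPU optimizer offload and BF16:

```bash
mpirun -n 2 \
    python lora_finetune_chatglm.py \
        ./AdvertiseGen_fix \
        THUDM/chatglm3-6b \
        ./lora_config.yaml \
        ./deepspeed_config.json
```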

@@ -0,0 +1,15 @@
{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu"
    },
    "contiguous_gradients": true,
    "overlap_comm": true
  },
  "bf16": {
    "enabled": true
  },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto"
}

@@ -0,0 +1,47 @@
# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/configs/lora.yaml
data_config:
  train_file: train.json
  val_file: dev.json
  test_file: dev.json
  num_proc: 16
max_input_length: 128
max_output_length: 128
training_args:
  # see `transformers.Seq2SeqTrainingArguments`
  output_dir: ./output
  max_steps: 3000
  # needed to be fit for the dataset
  learning_rate: 5e-5
  # settings for data loading
  per_device_train_batch_size: 1
  dataloader_num_workers: 16
  remove_unused_columns: false
  # settings for saving checkpoints
  save_strategy: steps
  save_steps: 500
  # settings for logging
  log_level: info
  logging_strategy: steps
  logging_steps: 10
  # settings for evaluation
  per_device_eval_batch_size: 4
  evaluation_strategy: steps
  eval_steps: 1000
  # settings for optimizer
  # adam_epsilon: 1e-6
  # uncomment the following line to detect nan or inf values
  # debug: underflow_overflow
  predict_with_generate: true
  # see `transformers.GenerationConfig`
  generation_config:
    max_new_tokens: 128
  # set your absolute deepspeed path here
  #deepspeed: ds_zero_2.json
  # set to true if train with cpu.
  use_cpu: false
peft_config:
  peft_type: LORA
  task_type: CAUSAL_LM
  r: 2
  lora_alpha: 8
  lora_dropout: 0.1

@@ -0,0 +1,603 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The below 2 lines differ from the original example: transformers is patched with IPEX-LLM
from ipex_llm import llm_patch
llm_patch(train=True)

# The example below is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/finetune_hf.py
# L417, L474 and L544-L546 are modified to enable the example on Intel Arc
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from peft import (
    PeftConfig,
    PeftModelForCausalLM,
    get_peft_config,
    get_peft_model
)
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    Seq2SeqTrainingArguments, AutoConfig,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq

from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = (
            [feature['output_ids'] for feature in features]
            if 'output_ids' in features[0].keys()
            else None
        )
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1) //
                    self.pad_to_multiple_of * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                    max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):
                    feature['output_ids'] = feature['output_ids'] + remainder
                else:
                    feature['output_ids'] = np.concatenate(
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)


class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.args.predict_with_generate:
            output_ids = inputs.pop('output_ids')
        input_ids = inputs['input_ids']
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
        )
        generated_tokens = generated_tokens[:, input_ids.size()[1]:]
        if self.args.predict_with_generate:
            labels = output_ids
        return loss, generated_tokens, labels


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def _sanity_check(
        input_ids: Sequence[int],
        output_ids: Sequence[int],
        tokenizer: PreTrainedTokenizer,
):
    print('--> Sanity check')
    for in_id, out_id in zip(input_ids, output_ids):
        if in_id == 0:
            continue
        if in_id in tokenizer.tokenizer.index_special_tokens:
            in_text = tokenizer.tokenizer.index_special_tokens[in_id]
        else:
            in_text = tokenizer.decode([in_id])
        print(f'{repr(in_text):>20}: {in_id} -> {out_id}')


@functools.cache
def _get_yaml_parser() -> yaml.YAML:
    parser = yaml.YAML(typ='safe', pure=True)
    parser.indent(mapping=2, offset=2, sequence=4)
    parser.default_flow_style = False
    return parser


@dc.dataclass
class DataConfig(object):
    train_file: str
    val_file: Optional[str] = None
    test_file: Optional[str] = None

    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            # skips the evaluation stage when `do_eval` or `eval_file` is not provided
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = 'no'
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size
                or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        training_args = kwargs.get('training_args', None)
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get('generation_config')
            # TODO: a bit hacky
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get('data_config')
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        peft_config = kwargs.get('peft_config', None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        path = _resolve_path(path)
        kwargs = _get_yaml_parser().load(path)
        return cls.from_dict(**kwargs)


def _load_datasets(
        data_dir: Path,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    if data_format in ('.csv', '.json', '.jsonl'):
        dataset_dct = load_dataset(
            data_format[1:],
            data_dir=data_dir,
            data_files=data_files,
            num_proc=num_proc,
        )
    else:
        err_msg = f"Cannot load dataset in the '{data_format}' format."
        raise NotImplementedError(err_msg)

    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            _resolve_path(data_dir),
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
            self,
            split: NamedSplit,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


def print_model_size(model: PreTrainedModel):
    print("--> Model")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> model has {total_params / 1e6}M params\n")


def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_tools = batch.get('tools', None)
    batched_conv = batch['conversations']
    batched_input_ids = []
    batched_labels = []

    if batched_tools is None:
        batched_tools = [None] * len(batched_conv)

    for tools, conv in zip(batched_tools, batched_conv):
        input_ids, loss_masks = [
            tokenizer.get_command('[gMASK]'),
            tokenizer.get_command('sop'),
        ], [False, False]

        if tools is not None:
            raise NotImplementedError()

        for message in conv:
            if message['role'] in ('system', 'user'):
                loss_mask_val = False
            else:
                loss_mask_val = True

            if message['role'] == 'tool':
                raise NotImplementedError()
            else:
                new_input_ids = tokenizer.build_single_message(
                    message['role'], '', message['content']
                )
                new_loss_masks = [loss_mask_val] * len(new_input_ids)

            input_ids += new_input_ids
            loss_masks += new_loss_masks

        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}


def process_batch_eval(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_tools = batch.get('tools', None)
    batched_conv = batch['conversations']
    batched_input_ids = []
    # To avoid computing loss, we do not provide the `labels` field in the input dictionary.
    batched_output_ids = []

    if batched_tools is None:
        batched_tools = [None] * len(batched_conv)

    for tools, conv in zip(batched_tools, batched_conv):
        input_ids = [
            tokenizer.get_command('[gMASK]'),
            tokenizer.get_command('sop'),
        ]

        if tools is not None:
            raise NotImplementedError()

        for message in conv:
            if len(input_ids) >= max_input_length:
                break
            if message['role'] == 'tool':
                raise NotImplementedError()
            else:
                new_input_ids = tokenizer.build_single_message(
                    message['role'], '', message['content']
                )
                if message['role'] == 'assistant':
                    output_prompt, output_ids = (
                        new_input_ids[:1],
                        new_input_ids[1:],
                    )
                    output_ids.append(tokenizer.eos_token_id)
                    batched_input_ids.append(
                        input_ids[:max_input_length] + output_prompt[:1]
                    )
                    batched_output_ids.append(output_ids[:max_output_length])
                input_ids += new_input_ids
    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}


# Not sure if this is necessary, can set it to half.
# If train with cpu, cast all params to fp32 instead of trainable ones.
def _prepare_model_for_training(model: nn.Module, use_cpu: bool):
    for param in model.parameters():
        if param.requires_grad or use_cpu:
            param.data = param.data.to(torch.float32)


def load_tokenizer_and_model(
        model_dir: str,
        peft_config: Optional[PeftConfig] = None,
) -> tuple[PreTrainedTokenizer, nn.Module]:
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if peft_config is not None:
        if peft_config.peft_type.name == "PREFIX_TUNING":
            config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
            config.pre_seq_len = peft_config.num_virtual_tokens
            config.use_cache = False
            model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                trust_remote_code=True,
                config=config,
            )
        if peft_config.peft_type.name == "LORA":
            # Add below L417 to enable accelerator to schedule model to Intel Arc XPU
            os.environ["ACCELERATE_USE_XPU"] = "true"
            model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                trust_remote_code=True,
                empty_init=False,
                use_cache=False,
                torch_dtype=torch.bfloat16
            )

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        print(model)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False
        )
    print_model_size(model)
    return tokenizer, model


def compute_metrics(eval_preds: EvalPrediction, tokenizer: PreTrainedTokenizer):
    batched_pred_ids, batched_label_ids = eval_preds

    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v['f'] * 100, 4))
        metrics_dct['bleu-4'].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=SmoothingFunction().method3,
            )
        )
    return {k: np.mean(v) for k, v in metrics_dct.items()}


@app.command()
def main(
        data_dir: Annotated[str, typer.Argument(help='')],
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        config_file: Annotated[str, typer.Argument(help='')],
        # Add below L474, which is the path of the deepspeed config file, to enable finetuning on 2 Intel Arc XPU cards
        deepspeed_config_file: Annotated[str, typer.Argument(default='', help='if specified, will apply data parallel')],
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training'
        )
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print('train_dataset:', train_dataset)
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print('test_dataset:', test_dataset)

    # checks encoded dataset
    _sanity_check(
        train_dataset[0]["input_ids"], train_dataset[0]["labels"], tokenizer
    )

    # turn model to fp32
    _prepare_model_for_training(model, ft_config.training_args.use_cpu)

    ft_config.training_args.generation_config.pad_token_id = (
        tokenizer.pad_token_id
    )
    ft_config.training_args.generation_config.eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command('<|user|>'),
        tokenizer.get_command('<|observation|>'),
    ]
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    use_tokenizer = True
    if ft_config.peft_config is not None:
        use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True

    # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
    if deepspeed_config_file != '':
        ft_config.training_args.ddp_backend = "ccl"
        ft_config.training_args.deepspeed = deepspeed_config_file

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(50))),
        tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(auto_resume_from_checkpoint,
                      "The specified checkpoint sn(" +
                      auto_resume_from_checkpoint +
                      ") has not been saved. Please search for the correct checkpoint in the model output directory")

    # test stage
    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == '__main__':
    app()

@@ -0,0 +1,23 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export BIGDL_CHECK_DUPLICATE_IMPORT=0

# You can also set the remote model repository to a local model path
python lora_finetune_chatglm.py \
    ./AdvertiseGen_fix \
    THUDM/chatglm3-6b \
    ./lora_config.yaml

@@ -0,0 +1,29 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export BIGDL_CHECK_DUPLICATE_IMPORT=0

# You can also set the remote model repository to a local model path
mpirun -n 2 \
    python lora_finetune_chatglm.py \
        ./AdvertiseGen_fix \
        THUDM/chatglm3-6b \
        ./lora_config.yaml \
        ./deepspeed_config.json

@@ -0,0 +1,59 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb
# L60 is changed to enable users to finish all operations under one working directory

import json
from typing import Union
from pathlib import Path


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def _mkdir(dir_name: Union[str, Path]):
    dir_name = _resolve_path(dir_name)
    if not dir_name.is_dir():
        dir_name.mkdir(parents=True, exist_ok=False)


def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):
    def _convert(in_file: Path, out_file: Path):
        _mkdir(out_file.parent)
        with open(in_file, encoding='utf-8') as fin:
            with open(out_file, 'wt', encoding='utf-8') as fout:
                for line in fin:
                    dct = json.loads(line)
                    sample = {'conversations': [{'role': 'user', 'content': dct['content']},
                                                {'role': 'assistant', 'content': dct['summary']}]}
                    fout.write(json.dumps(sample, ensure_ascii=False) + '\n')

    data_dir = _resolve_path(data_dir)
    save_dir = _resolve_path(save_dir)

    train_file = data_dir / 'train.json'
    if train_file.is_file():
        out_file = save_dir / train_file.relative_to(data_dir)
        _convert(train_file, out_file)

    dev_file = data_dir / 'dev.json'
    if dev_file.is_file():
        out_file = save_dir / dev_file.relative_to(data_dir)
        _convert(dev_file, out_file)


convert_adgen('./AdvertiseGen', './AdvertiseGen_fix')

@@ -17,7 +17,7 @@ This folder contains examples of running different training mode with IPEX-LLM o
 |------------|-----------------------------------------------------------------|-----------------------------------------------------------------|
 | LLaMA 2/3 | [LoRA](LoRA), [QLoRA](QLoRA), [QA-LoRA](QA-LoRA), [ReLora](ReLora) | [HF-PEFT](HF-PEFT), [axolotl](axolotl) |
 | Mistral | [LoRA](DPO), [QLoRA](DPO) | [DPO](DPO) |
-| ChatGLM 3 | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
+| ChatGLM 3 | [LoRA](LoRA/chatglm_finetune#lora-fine-tuning-on-chatglm3-6b-with-ipex-llm), [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
 | Qwen-1.5 | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |
 | Baichuan2 | [QLoRA](QLoRA/alpaca-qlora#3-qlora-finetune) | HF-PEFT |