LISA Finetuning Example (#10743)
* enabling xetla only supports qtype=SYM_INT4 or FP8E5 * LISA Finetuning Example on gpu * update readme * add licence * Explain parameters of lisa & Move backend codes to src dir * fix style * fix style * update readme * support chatglm * fix style * fix style * update readme * fix
This commit is contained in:
parent
581ebf6104
commit
ff040c8f01
4 changed files with 327 additions and 1 deletions
75
python/llm/example/GPU/LLM-Finetuning/LISA/README.md
Normal file
75
python/llm/example/GPU/LLM-Finetuning/LISA/README.md
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
# LISA Finetuning with IPEX-LLM
|
||||
|
||||
This example refers to [LISA with LMFLow's DynamicLayerActivationCallback Class](https://github.com/OptimalScale/LMFlow/blob/f3b3b007ea526009172c355e9d52ffa146b9dc0c/src/lmflow/pipeline/finetuner.py#L301), and adds [LISA fintuning](https://arxiv.org/abs/2403.17919) to IPEX-LLM on [Intel GPU](../../../GPU/README.md), based on [LORA finetuning with IPEX-LLM](../LoRA/alpaca_lora_finetuning.py).
|
||||
|
||||
### 0. Requirements
|
||||
|
||||
To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../../../GPU/README.md#requirements) for more information.
|
||||
|
||||
### 1. Install
|
||||
|
||||
```bash
|
||||
conda create -n llm python=3.11
|
||||
conda activate llm
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
pip install accelerate==0.23.0
|
||||
pip install bitsandbytes==0.43.0
|
||||
pip install datasets==2.18.0
|
||||
pip install --upgrade transformers==4.36.0
|
||||
pip install scipy fire
|
||||
```
|
||||
|
||||
### 2. LISA Finetune
|
||||
|
||||
```bash
|
||||
# Configures OneAPI environment variables
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
python ./lisa_finetuning.py \
|
||||
--micro_batch_size 8 \
|
||||
--batch_size 128 \
|
||||
--base_model "meta-llama/Llama-2-7b-hf" \
|
||||
--data_path "yahma/alpaca-cleaned" \
|
||||
--output_dir "./ipex-llm-lisa-alpaca" \
|
||||
--gradient_checkpointing True \
|
||||
--lisa_activated_layers 1 \
|
||||
--lisa_interval_steps 20
|
||||
```
|
||||
|
||||
Optional parameters for `lisa_finetuning.py`:
|
||||
|
||||
**--repo-id-or-model-path** : default to `meta-llama/Llama-2-7b-hf`, and you can also specify your local model path.
|
||||
|
||||
**--data-path** : default to `yahma/alpaca-cleaned`, and you can also specify your local datal path, while note that changing to the other datasets will introduce code modification effort for yourself.
|
||||
|
||||
**--output-dir** : default to `./ipex-llm-lisa-alpaca` to save fine-tuned model, and you can change if needed.
|
||||
|
||||
**--lisa_activated_layers** : the number of self-attention layers randomly selected to activate.
|
||||
|
||||
**lisa_interval_steps** : the number of interval steps to switch active layers.
|
||||
|
||||
### 3. Sample Output
|
||||
|
||||
```log
|
||||
......
|
||||
{'loss': 1.8391, 'learning_rate': 1.9967238104745695e-05, 'epoch': 0.03}
|
||||
{'loss': 1.8242, 'learning_rate': 1.9869167087338908e-05, 'epoch': 0.05}
|
||||
5%|██████▉ | 20/388 [xx:xx<x:xx:xx, x.xxs/it]
|
||||
Activating layers at indices: [10] for the next steps.
|
||||
{'loss': 1.8128, 'learning_rate': 1.9706429546259592e-05, 'epoch': 0.08}
|
||||
{'loss': 1.775, 'learning_rate': 1.9480091799562706e-05, 'epoch': 0.1}
|
||||
10%|██████████████ | 40/388 [xx:xx<xx:xx, x.xxs/it]
|
||||
Activating layers at indices: [30] for the next steps.
|
||||
{'loss': 1.7669, 'learning_rate': 1.9191636897958123e-05, 'epoch': 0.13}
|
||||
{'loss': 1.7749, 'learning_rate': 1.8842954907300236e-05, 'epoch': 0.15}
|
||||
15%|█████████████████████ | 60/388 [xx:xx<xx:xx, x.xxs/it]
|
||||
Activating layers at indices: [26] for the next steps.
|
||||
{'loss': 1.7735, 'learning_rate': 1.8436330524160048e-05, 'epoch': 0.18}
|
||||
{'loss': 1.7199, 'learning_rate': 1.797442810562721e-05, 'epoch': 0.21}
|
||||
21%|████████████████████████████ | 80/388 [xx:xx<xx:xx, x.xxs/it]
|
||||
Activating layers at indices: [17] for the next steps.
|
||||
{'loss': 1.7328, 'learning_rate': 1.7460274211432463e-05, 'epoch': 0.23}
|
||||
25%|█████████████████████████████████▋ | 96/388 [xx:xx<xx:xx, x.xxs/it]
|
||||
......
|
||||
|
||||
```
|
||||
167
python/llm/example/GPU/LLM-Finetuning/LISA/lisa_finetuning.py
Normal file
167
python/llm/example/GPU/LLM-Finetuning/LISA/lisa_finetuning.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
import fire
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
import accelerate
|
||||
from ipex_llm.transformers.lisa import DynamicLayerActivationCallback
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
common_util_path = os.path.join(current_dir, '..')
|
||||
import sys
|
||||
sys.path.append(common_util_path)
|
||||
from common.utils import Prompter, get_train_val_data
|
||||
|
||||
from ipex_llm.transformers import AutoModelForCausalLM
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
|
||||
def train(
|
||||
# model/data params
|
||||
base_model: str = "meta-llama/Llama-2-7b-hf", # the only required argument, default to be "meta-llama/Llama-2-7b-hf"
|
||||
data_path: str = "yahma/alpaca-cleaned",
|
||||
output_dir: str = "./ipex-llm-lisa-alpaca",
|
||||
# training hyperparams
|
||||
bf16: bool = True, # default to bf16
|
||||
batch_size: int = 128,
|
||||
micro_batch_size: int = 8, # default to be 8, limited by GPU memory
|
||||
num_epochs: int = 1,
|
||||
learning_rate: float = 2e-5,
|
||||
cutoff_len: int = 256,
|
||||
val_set_size: int = 2000,
|
||||
# llm hyperparams
|
||||
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
||||
add_eos_token: bool = False,
|
||||
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
||||
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
||||
prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
|
||||
gradient_checkpointing: bool = False,
|
||||
deepspeed: str = None,
|
||||
training_mode: str = "lisa",
|
||||
lisa_activated_layers: int = 1,
|
||||
lisa_interval_steps: int = 20,
|
||||
):
|
||||
invalidInputError(training_mode == "lisa",
|
||||
f"This example is for lisa training mode, but got training_mode={training_mode}.")
|
||||
print(
|
||||
f"Training Alpaca-LoRA model with params:\n"
|
||||
f"base_model: {base_model}\n"
|
||||
f"data_path: {data_path}\n"
|
||||
f"output_dir: {output_dir}\n"
|
||||
f"batch_size: {batch_size}\n"
|
||||
f"micro_batch_size: {micro_batch_size}\n"
|
||||
f"num_epochs: {num_epochs}\n"
|
||||
f"learning_rate: {learning_rate}\n"
|
||||
f"cutoff_len: {cutoff_len}\n"
|
||||
f"val_set_size: {val_set_size}\n"
|
||||
f"train_on_inputs: {train_on_inputs}\n"
|
||||
f"add_eos_token: {add_eos_token}\n"
|
||||
f"group_by_length: {group_by_length}\n"
|
||||
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
|
||||
f"prompt template: {prompt_template_name}\n"
|
||||
f"training_mode: {training_mode}\n"
|
||||
f"lisa_activated_layers: {lisa_activated_layers}\n"
|
||||
f"lisa_interval_steps: {lisa_interval_steps}\n"
|
||||
)
|
||||
assert (
|
||||
base_model
|
||||
), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
|
||||
gradient_accumulation_steps = batch_size // micro_batch_size
|
||||
|
||||
prompter = Prompter(prompt_template_name)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_low_bit="bf16",
|
||||
optimize_model=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True,
|
||||
enable_xetla=False
|
||||
)
|
||||
|
||||
model = model.to("xpu")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
|
||||
|
||||
tokenizer.pad_token_id = (
|
||||
0 # unk. we want this to be different from the eos token
|
||||
)
|
||||
tokenizer.padding_side = "left" # Allow batched inference
|
||||
|
||||
print(model)
|
||||
|
||||
if data_path.endswith(".json") or data_path.endswith(".jsonl"):
|
||||
data = load_dataset("json", data_files=data_path)
|
||||
else:
|
||||
data = load_dataset(data_path)
|
||||
|
||||
train_data, val_data = get_train_val_data(data, tokenizer, prompter, train_on_inputs,
|
||||
add_eos_token, cutoff_len, val_set_size, seed=42)
|
||||
|
||||
trainer_callbacks = []
|
||||
|
||||
# Instantiate the callback
|
||||
dynamic_layer_activation_callback = DynamicLayerActivationCallback(
|
||||
n_layers=lisa_activated_layers, # Number of layers to activate
|
||||
interval_steps = lisa_interval_steps, # Step interval to update active layers
|
||||
model = model
|
||||
)
|
||||
trainer_callbacks.append(dynamic_layer_activation_callback)
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=train_data,
|
||||
eval_dataset=val_data,
|
||||
args=transformers.TrainingArguments(
|
||||
per_device_train_batch_size=micro_batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
max_grad_norm=0.3,
|
||||
num_train_epochs=num_epochs,
|
||||
learning_rate=learning_rate,
|
||||
lr_scheduler_type="cosine",
|
||||
bf16=bf16, # ensure training more stable
|
||||
logging_steps=10,
|
||||
optim="adamw_hf",
|
||||
evaluation_strategy="steps" if val_set_size > 0 else "no",
|
||||
save_strategy="steps",
|
||||
eval_steps=200 if val_set_size > 0 else None,
|
||||
save_steps=200,
|
||||
output_dir=output_dir,
|
||||
load_best_model_at_end=True if val_set_size > 0 else False,
|
||||
group_by_length=group_by_length,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
deepspeed=deepspeed,
|
||||
save_safetensors=False,
|
||||
),
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
||||
),
|
||||
callbacks=trainer_callbacks
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
# model.save_pretrained(output_dir)
|
||||
trainer.save_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(train)
|
||||
81
python/llm/src/ipex_llm/transformers/lisa.py
Normal file
81
python/llm/src/ipex_llm/transformers/lisa.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from transformers import TrainerCallback
|
||||
import numpy as np
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
# source: https://github.com/OptimalScale/LMFlow/blob/main/src/lmflow/pipeline/finetuner.py
|
||||
class DynamicLayerActivationCallback(TrainerCallback):
|
||||
def __init__(self, n_layers, interval_steps, model):
|
||||
super().__init__()
|
||||
self.n_layers = n_layers
|
||||
self.interval_steps = interval_steps
|
||||
self.model = model
|
||||
|
||||
# Determine the way to access layers based on the model type
|
||||
class_to_layers_map = {
|
||||
'LlamaForCausalLM': 'model.model.layers',
|
||||
'Qwen2ForCausalLM': 'model.model.layers',
|
||||
'MistralForCausalLM': 'model.model.layers',
|
||||
'MixtralForCausalLM': 'model.model.layers',
|
||||
'GemmaForCausalLM': 'model.model.layers',
|
||||
'GPT2LMHeadModel': 'model.transformer.h',
|
||||
'ChatGLMModel': 'model.transformer.encoder.layers',
|
||||
}
|
||||
model_class_name = self.model.__class__.__name__
|
||||
if model_class_name in class_to_layers_map:
|
||||
self.layers_attribute = class_to_layers_map[model_class_name]
|
||||
else:
|
||||
# self.layers_attribute = training_args.lisa_layers_attribute
|
||||
invalidInputError(False, f"Model {model_class_name} not supported.")
|
||||
# Dynamically execute to get the number of layers
|
||||
self.total_layers = len(eval('self.' + self.layers_attribute))
|
||||
|
||||
self.active_layers_indices = []
|
||||
|
||||
def freeze_all_layers(self):
|
||||
layers = eval('self.' + self.layers_attribute) # Dynamically execute to get layers
|
||||
for layer in layers:
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def on_step_begin(self, args, state, control, **kwargs):
|
||||
# Check if it's time to switch active layers, including at step 0
|
||||
if state.global_step % self.interval_steps == 0:
|
||||
self.switch_active_layers()
|
||||
|
||||
def switch_active_layers(self):
|
||||
# First, disable gradients for all layers
|
||||
self.freeze_all_layers()
|
||||
|
||||
# Randomly select n_layers to activate
|
||||
layers = eval('self.' + self.layers_attribute) # Re-fetch layer references
|
||||
self.active_layers_indices = np.random.choice(
|
||||
range(self.total_layers),
|
||||
self.n_layers,
|
||||
replace=False
|
||||
)
|
||||
print(
|
||||
f"Activating layers at indices: {self.active_layers_indices} for the next steps.",
|
||||
flush=True
|
||||
)
|
||||
|
||||
# Enable gradients only for the selected layers
|
||||
for idx in self.active_layers_indices:
|
||||
for param in layers[idx].parameters():
|
||||
param.requires_grad = True
|
||||
|
|
@ -313,7 +313,10 @@ def should_use_xetla_mm_qkv(self, device):
|
|||
full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
|
||||
supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
|
||||
supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
|
||||
enable_xetla = self.q_proj.enable_xetla
|
||||
if self.q_proj.qtype == SYM_INT4 or self.q_proj.qtype == FP8E5:
|
||||
enable_xetla = self.q_proj.enable_xetla
|
||||
else:
|
||||
enable_xetla = False
|
||||
return device.type == "xpu" and enable_xetla and supported_qtype
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue