diff --git a/python/llm/example/GPU/LLM-Finetuning/LISA/README.md b/python/llm/example/GPU/LLM-Finetuning/LISA/README.md
new file mode 100644
index 00000000..dfc9f1b3
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/LISA/README.md
@@ -0,0 +1,75 @@
+# LISA Finetuning with IPEX-LLM
+
+This example refers to [LISA with LMFlow's DynamicLayerActivationCallback Class](https://github.com/OptimalScale/LMFlow/blob/f3b3b007ea526009172c355e9d52ffa146b9dc0c/src/lmflow/pipeline/finetuner.py#L301), and adds [LISA finetuning](https://arxiv.org/abs/2403.17919) to IPEX-LLM on [Intel GPU](../../../GPU/README.md), based on [LoRA finetuning with IPEX-LLM](../LoRA/alpaca_lora_finetuning.py).
+
+### 0. Requirements
+
+To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../GPU/README.md#requirements) for more information.
+
+### 1. Install
+
+```bash
+conda create -n llm python=3.11
+conda activate llm
+# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
+pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install accelerate==0.23.0
+pip install bitsandbytes==0.43.0
+pip install datasets==2.18.0
+pip install --upgrade transformers==4.36.0
+pip install scipy fire
+```
+
+### 2. LISA Finetune
+
+```bash
+# Configures OneAPI environment variables
+source /opt/intel/oneapi/setvars.sh
+python ./lisa_finetuning.py \
+    --micro_batch_size 8 \
+    --batch_size 128 \
+    --base_model "meta-llama/Llama-2-7b-hf" \
+    --data_path "yahma/alpaca-cleaned" \
+    --output_dir "./ipex-llm-lisa-alpaca" \
+    --gradient_checkpointing True \
+    --lisa_activated_layers 1 \
+    --lisa_interval_steps 20
+```
+
+Optional parameters for `lisa_finetuning.py`:
+
+**--base_model** : defaults to `meta-llama/Llama-2-7b-hf`; you can also specify a local model path.
+
+**--data_path** : defaults to `yahma/alpaca-cleaned`; you can also specify a local data path. Note that switching to a different dataset will require you to modify the data-processing code yourself.
+
+**--output_dir** : defaults to `./ipex-llm-lisa-alpaca`, where the fine-tuned model is saved; change it if needed.
+
+**--lisa_activated_layers** : the number of model layers randomly selected to activate (unfreeze) at each switch.
+
+**--lisa_interval_steps** : the number of training steps between switches of the active layers.
+
+### 3. Sample Output
+
+```log
+......
+{'loss': 1.8391, 'learning_rate': 1.9967238104745695e-05, 'epoch': 0.03}
+{'loss': 1.8242, 'learning_rate': 1.9869167087338908e-05, 'epoch': 0.05}
+  5%|██████▉                | 20/388 [xx:xx
+            evaluation_strategy="steps" if val_set_size > 0 else "no",
+            save_strategy="steps",
+            eval_steps=200 if val_set_size > 0 else None,
+            save_steps=200,
+            output_dir=output_dir,
+            load_best_model_at_end=True if val_set_size > 0 else False,
+            group_by_length=group_by_length,
+            gradient_checkpointing=gradient_checkpointing,
+            deepspeed=deepspeed,
+            save_safetensors=False,
+        ),
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+        callbacks=trainer_callbacks
+    )
+    model.config.use_cache = False
+
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    # model.save_pretrained(output_dir)
+    trainer.save_model()
+
+
+if __name__ == "__main__":
+    fire.Fire(train)
diff --git a/python/llm/src/ipex_llm/transformers/lisa.py b/python/llm/src/ipex_llm/transformers/lisa.py
new file mode 100644
index 00000000..83ca64e4
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/lisa.py
@@ -0,0 +1,81 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from transformers import TrainerCallback
+import numpy as np
+from ipex_llm.utils.common import invalidInputError
+
+
+# source: https://github.com/OptimalScale/LMFlow/blob/main/src/lmflow/pipeline/finetuner.py
+class DynamicLayerActivationCallback(TrainerCallback):
+    def __init__(self, n_layers, interval_steps, model):
+        super().__init__()
+        self.n_layers = n_layers
+        self.interval_steps = interval_steps
+        self.model = model
+
+        # Determine the way to access layers based on the model type
+        class_to_layers_map = {
+            'LlamaForCausalLM': 'model.model.layers',
+            'Qwen2ForCausalLM': 'model.model.layers',
+            'MistralForCausalLM': 'model.model.layers',
+            'MixtralForCausalLM': 'model.model.layers',
+            'GemmaForCausalLM': 'model.model.layers',
+            'GPT2LMHeadModel': 'model.transformer.h',
+            'ChatGLMModel': 'model.transformer.encoder.layers',
+        }
+        model_class_name = self.model.__class__.__name__
+        if model_class_name in class_to_layers_map:
+            self.layers_attribute = class_to_layers_map[model_class_name]
+        else:
+            # self.layers_attribute = training_args.lisa_layers_attribute
+            invalidInputError(False, f"Model {model_class_name} not supported.")
+        # Dynamically execute to get the number of layers
+        self.total_layers = len(eval('self.' + self.layers_attribute))
+
+        self.active_layers_indices = []
+
+    def freeze_all_layers(self):
+        layers = eval('self.' + self.layers_attribute)  # Dynamically execute to get layers
+        for layer in layers:
+            for param in layer.parameters():
+                param.requires_grad = False
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        # Check if it's time to switch active layers, including at step 0
+        if state.global_step % self.interval_steps == 0:
+            self.switch_active_layers()
+
+    def switch_active_layers(self):
+        # First, disable gradients for all layers
+        self.freeze_all_layers()
+
+        # Randomly select n_layers to activate
+        layers = eval('self.' + self.layers_attribute)  # Re-fetch layer references
+        self.active_layers_indices = np.random.choice(
+            range(self.total_layers),
+            self.n_layers,
+            replace=False
+        )
+        print(
+            f"Activating layers at indices: {self.active_layers_indices} for the next steps.",
+            flush=True
+        )
+
+        # Enable gradients only for the selected layers
+        for idx in self.active_layers_indices:
+            for param in layers[idx].parameters():
+                param.requires_grad = True
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index a98f02a4..8eda2212 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -313,7 +313,10 @@ def should_use_xetla_mm_qkv(self, device):
     full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
     supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
     supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
-    enable_xetla = self.q_proj.enable_xetla
+    if self.q_proj.qtype == SYM_INT4 or self.q_proj.qtype == FP8E5:
+        enable_xetla = self.q_proj.enable_xetla
+    else:
+        enable_xetla = False
     return device.type == "xpu" and enable_xetla and supported_qtype
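
As a quick usage reference for the new callback, below is a minimal sketch of how `DynamicLayerActivationCallback` (added above in `ipex_llm/transformers/lisa.py`) can be attached to a stock Hugging Face `Trainer`. This is only an illustration, not the example's `lisa_finetuning.py`: the `gpt2` model, the dummy dataset, and the `TrainingArguments` values are placeholders chosen to keep the snippet small and runnable, whereas the real example builds the callback from `--lisa_activated_layers` / `--lisa_interval_steps` and passes it via `callbacks=trainer_callbacks` as shown in the diff.

```python
import transformers
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from ipex_llm.transformers.lisa import DynamicLayerActivationCallback

# Placeholder model: GPT2LMHeadModel is one of the architectures listed in
# class_to_layers_map; the example itself uses meta-llama/Llama-2-7b-hf.
base_model = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained(base_model)

# Tiny dummy dataset, only to keep this snippet self-contained; the example
# uses yahma/alpaca-cleaned with its own prompt/tokenization pipeline.
def tokenize(sample):
    ids = tokenizer(sample["text"], truncation=True, max_length=64)
    ids["labels"] = ids["input_ids"].copy()
    return ids

train_data = Dataset.from_dict(
    {"text": ["### Instruction: say hi\n### Response: hi"] * 8}
).map(tokenize, remove_columns=["text"])

# LISA: keep all transformer layers frozen, and every `interval_steps` optimizer
# steps randomly unfreeze `n_layers` of them (handled by the callback's
# on_step_begin / switch_active_layers).
lisa_callback = DynamicLayerActivationCallback(
    n_layers=1,         # corresponds to --lisa_activated_layers
    interval_steps=20,  # corresponds to --lisa_interval_steps
    model=model,
)

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        learning_rate=2e-5,
        output_dir="./lisa-sketch-output",  # placeholder output directory
        save_safetensors=False,
        report_to="none",
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
    callbacks=[lisa_callback],
)
model.config.use_cache = False
trainer.train()
```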