QA-LoRA example (#9551)

* Support qa-lora

* init

* update

* update

* update

* update

* update

* update merge

* update

* fix style & update scripts

* update

* address comments

* fix typo

* fix typo

---------

Co-authored-by: Yang Wang <yang3.wang@intel.com>
Yina Chen 2023-12-06 15:36:21 +08:00 committed by GitHub
parent 6978b2c316
commit 404e101ded
12 changed files with 281 additions and 54 deletions

View file

@@ -1,6 +1,6 @@
-# Alpaca QLoRA Finetuning (experimental support)
+# Alpaca QLoRA & QA-LoRA Finetuning (experimental support)
-This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM QLoRA on [Intel GPUs](../../README.md).
+This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using either [QLoRA](https://arxiv.org/abs/2305.14314) or [QA-LoRA](https://arxiv.org/abs/2309.14717) algorithm) on [Intel GPU](../../README.md).
 ### 0. Requirements
 To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../../README.md#requirements) for more information.
@@ -28,54 +28,75 @@ source /opt/intel/oneapi/setvars.sh
 Here, we provide example usages on different hardware. Please refer to the appropriate script based on your device:
-#### Finetuning LLaMA2-7B on single Arc A770
+#### QLoRA
+##### Finetuning LLaMA2-7B on single Arc A770
 ```bash
 bash finetune_llama2_7b_arc_1_card.sh
 ```
-#### Finetuning LLaMA2-7B on two Arc A770
+##### Finetuning LLaMA2-7B on two Arc A770
 ```bash
 bash finetune_llama2_7b_arc_2_card.sh
 ```
-#### Finetuning LLaMA2-7B on single Data Center GPU Flex 170
+##### Finetuning LLaMA2-7B on single Data Center GPU Flex 170
 ```bash
 bash finetune_llama2_7b_flex_170_1_card.sh
 ```
-#### Finetuning LLaMA2-7B on three Data Center GPU Flex 170
+##### Finetuning LLaMA2-7B on three Data Center GPU Flex 170
 ```bash
 bash finetune_llama2_7b_flex_170_3_card.sh
 ```
-#### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1100
+##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1100
 ```bash
 bash finetune_llama2_7b_pvc_1100_1_card.sh
 ```
-#### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
+##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
 ```bash
 bash finetune_llama2_7b_pvc_1100_4_card.sh
 ```
-#### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550
+##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550
 ```bash
 bash finetune_llama2_7b_pvc_1550_1_card.sh
 ```
-#### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
+##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
 ```bash
 bash finetune_llama2_7b_pvc_1550_4_card.sh
 ```
+#### QA-LoRA
+##### Finetuning LLaMA2-7B on single Arc A770
+```bash
+bash qalora_finetune_llama2_7b_arc_1_card.sh
+```
+##### Finetuning LLaMA2-7B on two Arc A770
+```bash
+bash qalora_finetune_llama2_7b_arc_2_card.sh
+```
+##### Finetuning LLaMA2-7B on single Tile Intel Data Center GPU Max 1550
+```bash
+bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh
+```
 **Important: If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by specifying `resume_from_checkpoint` to the local checkpoint folder as following:**
 ```bash
 python ./alpaca_qlora_finetuning.py \
@@ -97,3 +118,10 @@ python ./alpaca_qlora_finetuning.py \
 {'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
 1%|█ | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
 ```
+### 4. Merge the adapter into the original model
+```
+python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
+```
+Then you can use `./outputs/checkpoint-200-merged` as a normal huggingface transformer model to do inference.
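For instance, inference with the merged checkpoint can look roughly like the sketch below; the prompt, generation settings, and the `load_in_4bit` choice are illustrative and not part of this change.

```python
# Hypothetical inference with the merged model produced above; paths and prompt are placeholders.
import torch
from transformers import LlamaTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "./outputs/checkpoint-200-merged"
tokenizer = LlamaTokenizer.from_pretrained(model_path)
# load_in_4bit re-quantizes the merged weights for low-memory inference
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)

prompt = "### Instruction:\nTell me about Intel GPUs.\n\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```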

View file

@@ -41,7 +41,6 @@ import accelerate
 from transformers import LlamaTokenizer
 from peft import (
-    LoraConfig,
     get_peft_model_state_dict,
     set_peft_model_state_dict,
 )
@@ -51,7 +50,8 @@ import intel_extension_for_pytorch as ipex
 from bigdl.llm.transformers import AutoModelForCausalLM
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    cast_lora_weight, LoraConfig
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
@@ -109,6 +109,7 @@ def train(
     prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
     gradient_checkpointing: bool = False,
     deepspeed: str = None,
+    qa_lora: bool = False,  # if True, use qa-lora https://arxiv.org/abs/2309.14717
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -135,6 +136,7 @@ def train(
             f"wandb_log_model: {wandb_log_model}\n"
             f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
             f"prompt template: {prompt_template_name}\n"
+            f"qa_lora: {qa_lora}\n"
         )
     assert (
         base_model
@@ -171,10 +173,13 @@ def train(
             modules_to_not_convert=["lm_head"],
         )
     else:
-        # Load the base model from a directory or the HF Hub to 4-bit NormalFloat format
+        # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
+        # Default 4-bit format for qa-lora is sym_int4
+        low_bit_format = "sym_int4" if qa_lora else "nf4"
+        # Load the base model from a directory or the HF Hub to 4-bit format
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            load_in_low_bit="nf4", # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
+            load_in_low_bit=low_bit_format,
             optimize_model=False,
             torch_dtype=torch.bfloat16,
             # device_map=device_map,
@@ -252,7 +257,9 @@ def train(
         lora_dropout=lora_dropout,
         bias="none",
         task_type="CAUSAL_LM",
+        qa_lora=qa_lora,
     )
+    print(f"Lora Config: {config}")
     model = get_peft_model(model, config)
     if data_path.endswith(".json") or data_path.endswith(".jsonl"):
@@ -294,7 +301,7 @@ def train(
         max_grad_norm=0.3,
         num_train_epochs=num_epochs,
         learning_rate=learning_rate,
-        lr_scheduler_type="cosine",
+        lr_scheduler_type="constant" if qa_lora else "cosine",
         bf16=True,  # ensure training more stable
         logging_steps=1,
         optim="adamw_torch",

View file

@@ -0,0 +1,29 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--learning_rate 9e-5 \
--micro_batch_size 2 \
--batch_size 128 \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 \
--qa_lora True

View file

@@ -0,0 +1,34 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6 # adjust this to 1/4 of total physical cores
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--learning_rate 9e-5 \
--micro_batch_size 2 \
--batch_size 128 \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 \
--qa_lora True > training.log

View file

@@ -0,0 +1,34 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
mpirun -n 2 \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--qa_lora True \
--learning_rate 9e-5 \
--micro_batch_size 8 \
--batch_size 128 \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 > training.log

View file

@@ -0,0 +1,31 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
python ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-7b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./bigdl-qlora-alpaca" \
--learning_rate 9e-5 \
--micro_batch_size 8 \
--batch_size 128 \
--gradient_checkpointing False \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--val_set_size 2000 \
--qa_lora True

View file

@@ -31,11 +31,13 @@
 import os
 import torch
-import transformers
 from transformers import LlamaTokenizer  # noqa: F402
-from bigdl.llm.transformers.qlora import PeftModel
+from bigdl.llm.transformers.qlora import PeftModel, LoraConfig
 from bigdl.llm.transformers import AutoModelForCausalLM
+from bigdl.llm.transformers.low_bit_linear import get_block_size
 import argparse
+import tempfile
+import shutil
 if __name__ == "__main__":
@@ -51,31 +53,66 @@ if __name__ == "__main__":
     adapter_path = args.adapter_path
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
-    base_model = AutoModelForCausalLM.from_pretrained(
-        base_model,
-        # load_in_low_bit="nf4", # should load the orignal model
-        torch_dtype=torch.float16,
-        device_map={"": "cpu"},
-    )
-    lora_model = PeftModel.from_pretrained(
-        base_model,
-        adapter_path,
-        device_map={"": "cpu"},
-        torch_dtype=torch.float16,
-    )
-    # merge weights - new merging method from peft
-    lora_model = lora_model.merge_and_unload()
-    lora_model.train(False)
-    lora_model_sd = lora_model.state_dict()
-    deloreanized_sd = {
-        k.replace("base_model.model.", ""): v
-        for k, v in lora_model_sd.items()
-        if "lora" not in k
-    }
-    base_model.save_pretrained(args.output_path, state_dict=deloreanized_sd)
-    tokenizer.save_pretrained(args.output_path)
+    lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json"))
+    qa_lora = lora_config.get("qa_lora", False)
+    temp_dir = None
+    if qa_lora:
+        # Convert the qa-lora adapter to the correct shapes
+        # The default 4-bit format for qa_lora is sym_int4
+        block_size = get_block_size("sym_int4")
+        temp_dir = tempfile.TemporaryDirectory()
+        tmpdirname = os.path.join(temp_dir.name, "adapter")
+        try:
+            shutil.copytree(adapter_path, tmpdirname)
+        except Exception as e:
+            print(f"Failed to copy adapter dir, error: {e}")
+        mid_lora_path = os.path.join(tmpdirname, "adapter_model.bin")
+        adapter_path = os.path.join(adapter_path, "adapter_model.bin")
+        lora = torch.load(adapter_path, map_location='cpu')
+        # Get lora_a names
+        tmp_keys = [key for key in lora.keys() if 'lora_A' in key]
+        for tmp_key in tmp_keys:
+            lora_a = lora[tmp_key] / block_size
+            lora[tmp_key] = torch.repeat_interleave(lora_a, block_size, dim=1)
+        torch.save(lora, mid_lora_path)
+        adapter_path = tmpdirname
+    try:
+        base_model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            # load_in_low_bit="nf4", # should load the orignal model
+            torch_dtype=torch.float16,
+            device_map={"": "cpu"},
+        )
+        lora_model = PeftModel.from_pretrained(
+            base_model,
+            adapter_path,
+            device_map={"": "cpu"},
+            torch_dtype=torch.float16,
+        )
+        # merge weights - new merging method from peft
+        lora_model = lora_model.merge_and_unload()
+        lora_model.train(False)
+        lora_model_sd = lora_model.state_dict()
+        deloreanized_sd = {
+            k.replace("base_model.model.", ""): v
+            for k, v in lora_model_sd.items()
+            if "lora" not in k
+        }
+        base_model.save_pretrained(args.output_path, state_dict=deloreanized_sd)
+        tokenizer.save_pretrained(args.output_path)
+    except Exception as e:
+        print(f"Failed to merge the adapter, error: {e}.")
+    finally:
+        if qa_lora and temp_dir:
+            temp_dir.cleanup()
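Why the divide-and-repeat step above is enough to turn a QA-LoRA adapter into an ordinary LoRA adapter: `lora_A` in QA-LoRA acts on group-averaged inputs, and averaging each group of `block_size` features is the same linear map as repeating every `lora_A` column `block_size` times and scaling by `1/block_size`. A small self-contained check (all sizes below are made up for illustration):

```python
# Check that expanding lora_A (divide by block_size, then repeat_interleave along dim=1)
# reproduces the QA-LoRA forward path. Sizes are illustrative, not taken from a real model.
import torch

block_size = 32                 # quantization group size (illustrative)
in_features, out_features, r = 256, 16, 8

x = torch.randn(in_features)
lora_A = torch.randn(r, in_features // block_size)  # QA-LoRA A sees pooled inputs
lora_B = torch.randn(out_features, r)

# QA-LoRA path: average-pool the input over each quantization group, then apply LoRA
pooled = torch.nn.AvgPool1d(block_size)(x.view(1, 1, -1)).flatten()
qa_lora_out = lora_B @ (lora_A @ pooled)

# Merged path: the expanded adapter applied to the raw input, as in export_merged_model.py
lora_A_expanded = torch.repeat_interleave(lora_A / block_size, block_size, dim=1)
merged_out = lora_B @ (lora_A_expanded @ x)

assert torch.allclose(qa_lora_out, merged_out, atol=1e-5)
```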

View file

@@ -19,9 +19,9 @@ import os
 import transformers
 from transformers import LlamaTokenizer
-from peft import LoraConfig
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    LoraConfig
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse

View file

@@ -109,8 +109,8 @@ def is_linear_module(module):
 def convert_gptq(module, awq=False):
-    from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size
-    Q4_1 = get_ggml_qk_size("asym_int4")
+    from bigdl.llm.transformers.low_bit_linear import get_block_size
+    Q4_1 = get_block_size("asym_int4")
     scales = module.scales

View file

@ -71,10 +71,14 @@ MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
MOFQ8 = ggml_tensor_qtype["mixed_fp8"] MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
def get_ggml_qk_size(qtype: str): def get_block_size(qtype: str):
return ggml.ggml_qk_size(ggml_tensor_qtype[qtype]) return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
def get_qk_size(qtype: int):
return ggml.ggml_qk_size(qtype)
def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
device=None, convert_shape_only=False): device=None, convert_shape_only=False):
QK = ggml.ggml_qk_size(qtype) QK = ggml.ggml_qk_size(qtype)

View file

@@ -124,7 +124,7 @@ class _BaseAutoModelClass:
         if load_in_4bit or load_in_low_bit:
             if config_dict.get("quantization_config", None) is not None:
-                from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size
+                from bigdl.llm.transformers.low_bit_linear import get_block_size
                 q_config = config_dict["quantization_config"]
                 if q_config["quant_method"] == "gptq":
                     invalidInputError(q_config["bits"] == 4,
@@ -136,10 +136,10 @@ class _BaseAutoModelClass:
                                       "You can only load gptq model as aysm_int4 low bit type.")
                     load_in_low_bit = "asym_int4"
-                    if int(q_config["group_size"]) % get_ggml_qk_size(load_in_low_bit) != 0:
+                    if int(q_config["group_size"]) % get_block_size(load_in_low_bit) != 0:
                         invalidInputError(False,
                                           (f"group_size must be divisible by "
-                                           f"{get_ggml_qk_size(load_in_low_bit)}."))
+                                           f"{get_block_size(load_in_low_bit)}."))
                     if user_quantization_config is not None:
                         invalidInputError(user_quantization_config.bits == 4,
                                           "Only 4-bit gptq is supported in bigdl-llm.")
@@ -166,10 +166,10 @@ class _BaseAutoModelClass:
                     load_in_low_bit = "asym_int4"
-                    if int(awq_config.group_size) % get_ggml_qk_size(load_in_low_bit) != 0:
+                    if int(awq_config.group_size) % get_block_size(load_in_low_bit) != 0:
                         invalidInputError(False,
                                           (f"group_size must be divisible by "
-                                           f"{get_ggml_qk_size(load_in_low_bit)}."))
+                                           f"{get_block_size(load_in_low_bit)}."))
                     kwargs["quantization_config"] = awq_config

View file

@@ -49,7 +49,7 @@
 # limitations under the License.
 import torch
-from bigdl.llm.transformers.low_bit_linear import LowBitLinear
+from bigdl.llm.transformers.low_bit_linear import LowBitLinear, get_qk_size
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.utils import get_autocast_dtype
@@ -66,6 +66,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        qa_lora: bool = True,
         **kwargs,
     ):
         LowBitLinear.__init__(
@@ -76,7 +77,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             bias=kwargs.get("bias", True),
             conver_to_half=False,
         )
-        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+        qk_size = get_qk_size(kwargs.get("qtype"))
+        lora_in_features = in_features // qk_size if qa_lora else in_features
+        LoraLayer.__init__(self, in_features=lora_in_features, out_features=out_features)
         # Freezing the pre-trained weight matrix
         self.weight.requires_grad = False
@@ -84,6 +88,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
         self.active_adapter = adapter_name
+        if qa_lora:
+            self.qa_pool = torch.nn.AvgPool1d(qk_size)
+        else:
+            self.qa_pool = torch.nn.Identity()
     def forward(self, x: torch.Tensor):
         autocast_dtype = get_autocast_dtype(x)
@@ -103,14 +111,16 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             x = x.to(self.lora_A[self.active_adapter].weight.dtype)
             output = (
                 self.lora_B[self.active_adapter](
-                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    self.lora_A[self.active_adapter](
+                        self.lora_dropout[self.active_adapter](self.qa_pool(x)))
                 ).to(expected_dtype)
                 * self.scaling[self.active_adapter]
             )
         else:
             output = (
                 self.lora_B[self.active_adapter](
-                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    self.lora_A[self.active_adapter](
+                        self.lora_dropout[self.active_adapter](self.qa_pool(x)))
                 )
                 * self.scaling[self.active_adapter]
             )
@@ -126,6 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
     low_bit_kwargs.update(
         {
             "qtype": target.qtype,
+            "qa_lora": lora_config.qa_lora,
         }
     )
     new_module = LoraLowBitLinear(adapter_name,
@@ -140,6 +151,14 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
 from peft.tuners.lora import LoraModel
+from peft.tuners.lora import LoraConfig as LoraConfigBase
+from dataclasses import dataclass, field
+@dataclass
+class LoraConfig(LoraConfigBase):
+    qa_lora: bool = field(default=False, metadata={"help": "enable qa-lora"})
 def get_peft_model(*args, **kwargs):
@@ -357,6 +376,10 @@ def _setup_devices(self) -> "torch.device":
         torch.cuda.set_device(device)
     return device
+from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
+PEFT_TYPE_TO_CONFIG_MAPPING["lora"] = LoraConfig
 # workaround a IPEX bug that prevents resume training in bf16
 from accelerate import Accelerator
 Accelerator._prepare_ipex = patch_prepare_ipex
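To make the mechanism concrete: with `qa_lora=True`, each `LoraLowBitLinear` pools the activations over every group of `qk_size` input features before `lora_A`, so the adapter only has `in_features // qk_size` input columns and can later be folded into the per-group quantization parameters. A stripped-down sketch of that forward path follows; it is an illustrative stand-in, not the class above, and it omits quantization, dropout, and dtype handling.

```python
# Minimal stand-alone module mirroring the qa_pool idea from LoraLowBitLinear above.
import torch
import torch.nn as nn

class TinyQALoraAdapter(nn.Module):
    def __init__(self, in_features: int, out_features: int, r: int,
                 qk_size: int, qa_lora: bool = True):
        super().__init__()
        # QA-LoRA shrinks lora_A's input dimension to one column per quantization group
        lora_in = in_features // qk_size if qa_lora else in_features
        self.qa_pool = nn.AvgPool1d(qk_size) if qa_lora else nn.Identity()
        self.lora_A = nn.Linear(lora_in, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.scaling = 2.0  # stands in for lora_alpha / r; value here is arbitrary

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, in_features); AvgPool1d averages over the last dimension
        return self.lora_B(self.lora_A(self.qa_pool(x))) * self.scaling

adapter = TinyQALoraAdapter(in_features=4096, out_features=4096, r=8, qk_size=32)
delta = adapter(torch.randn(1, 4, 4096))  # (1, 4, 4096), added to the frozen base output
```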