QALora example (#9551)
* Support qa-lora
* init
* update
* update
* update
* update
* update
* update merge
* update
* fix style & update scripts
* update
* address comments
* fix typo
* fix typo

---------

Co-authored-by: Yang Wang <yang3.wang@intel.com>
parent 6978b2c316, commit 404e101ded
12 changed files with 281 additions and 54 deletions
@@ -1,6 +1,6 @@
-# Alpaca QLoRA Finetuning (experimental support)
+# Alpaca QLoRA & QA-LoRA Finetuning (experimental support)
 
-This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM QLoRA on [Intel GPUs](../../README.md).
+This example ports [Alpaca-LoRA](https://github.com/tloen/alpaca-lora/tree/main) to BigDL-LLM (using either the [QLoRA](https://arxiv.org/abs/2305.14314) or the [QA-LoRA](https://arxiv.org/abs/2309.14717) algorithm) on [Intel GPU](../../README.md).
 
 ### 0. Requirements
 To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../../README.md#requirements) for more information.
@@ -28,54 +28,75 @@ source /opt/intel/oneapi/setvars.sh
 
 Here, we provide example usages on different hardware. Please refer to the appropriate script based on your device:
 
-#### Finetuning LLaMA2-7B on single Arc A770
+#### QLoRA
+
+##### Finetuning LLaMA2-7B on single Arc A770
 
 ```bash
 bash finetune_llama2_7b_arc_1_card.sh
 ```
 
-#### Finetuning LLaMA2-7B on two Arc A770
+##### Finetuning LLaMA2-7B on two Arc A770
 
 ```bash
 bash finetune_llama2_7b_arc_2_card.sh
 ```
 
-#### Finetuning LLaMA2-7B on single Data Center GPU Flex 170
+##### Finetuning LLaMA2-7B on single Data Center GPU Flex 170
 
 ```bash
 bash finetune_llama2_7b_flex_170_1_card.sh
 ```
 
-#### Finetuning LLaMA2-7B on three Data Center GPU Flex 170
+##### Finetuning LLaMA2-7B on three Data Center GPU Flex 170
 
 ```bash
 bash finetune_llama2_7b_flex_170_3_card.sh
 ```
 
-#### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1100
+##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1100
 
 ```bash
 bash finetune_llama2_7b_pvc_1100_1_card.sh
 ```
 
-#### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
+##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1100
 
 ```bash
 bash finetune_llama2_7b_pvc_1100_4_card.sh
 ```
 
-#### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550
+##### Finetuning LLaMA2-7B on single Intel Data Center GPU Max 1550
 
 ```bash
 bash finetune_llama2_7b_pvc_1550_1_card.sh
 ```
 
-#### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
+##### Finetuning LLaMA2-7B on four Intel Data Center GPU Max 1550
 
 ```bash
 bash finetune_llama2_7b_pvc_1550_4_card.sh
 ```
 
+#### QA-LoRA
+##### Finetuning LLaMA2-7B on single Arc A770
+
+```bash
+bash qalora_finetune_llama2_7b_arc_1_card.sh
+```
+
+##### Finetuning LLaMA2-7B on two Arc A770
+
+```bash
+bash qalora_finetune_llama2_7b_arc_2_card.sh
+```
+
+##### Finetuning LLaMA2-7B on single Tile Intel Data Center GPU Max 1550
+
+```bash
+bash qalora_finetune_llama2_7b_pvc_1550_1_tile.sh
+```
+
 **Important: If you fail to complete the whole finetuning process, it is suggested to resume training from a previously saved checkpoint by specifying `resume_from_checkpoint` to the local checkpoint folder as following:**
 ```bash
 python ./alpaca_qlora_finetuning.py \
@@ -97,3 +118,10 @@ python ./alpaca_qlora_finetuning.py \
 {'loss': 1.8552, 'learning_rate': 2.9996503623845395e-05, 'epoch': 0.02}
 1%|█          | 8/1164 [xx:xx<xx:xx:xx, xx s/it]
 ```
+
+### 4. Merge the adapter into the original model
+```
+python ./export_merged_model.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --adapter_path ./outputs/checkpoint-200 --output_path ./outputs/checkpoint-200-merged
+```
+
+Then you can use `./outputs/checkpoint-200-merged` as a normal huggingface transformer model to do inference.
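For reference, a minimal inference sketch with the merged checkpoint; the prompt, dtype, and generation settings below are illustrative assumptions, not part of this commit:

```python
# Minimal sketch: load the merged checkpoint as a regular Hugging Face model.
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

model_path = "./outputs/checkpoint-200-merged"
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)

inputs = tokenizer("What is AI?", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```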
@@ -41,7 +41,6 @@ import accelerate
 
 from transformers import LlamaTokenizer
 from peft import (
-    LoraConfig,
     get_peft_model_state_dict,
     set_peft_model_state_dict,
 )
@@ -51,7 +50,8 @@ import intel_extension_for_pytorch as ipex
 from bigdl.llm.transformers import AutoModelForCausalLM
 
 # import them from bigdl.llm.transformers.qlora to get a BigDL-LLM compatible Peft model
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    cast_lora_weight, LoraConfig
 
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
@@ -109,6 +109,7 @@ def train(
     prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
     gradient_checkpointing: bool = False,
     deepspeed: str = None,
+    qa_lora: bool = False,  # if True, use qa-lora https://arxiv.org/abs/2309.14717
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -135,6 +136,7 @@ def train(
             f"wandb_log_model: {wandb_log_model}\n"
             f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
             f"prompt template: {prompt_template_name}\n"
+            f"qa_lora: {qa_lora}\n"
         )
     assert (
         base_model
@@ -171,10 +173,13 @@ def train(
             modules_to_not_convert=["lm_head"],
         )
     else:
-        # Load the base model from a directory or the HF Hub to 4-bit NormalFloat format
+        # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
+        # Default 4-bit format for qa-lora is sym_int4
+        low_bit_format = "sym_int4" if qa_lora else "nf4"
+        # Load the base model from a directory or the HF Hub to 4-bit format
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            load_in_low_bit="nf4",  # According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
+            load_in_low_bit=low_bit_format,
             optimize_model=False,
             torch_dtype=torch.bfloat16,
             # device_map=device_map,
@@ -252,7 +257,9 @@ def train(
         lora_dropout=lora_dropout,
         bias="none",
         task_type="CAUSAL_LM",
+        qa_lora=qa_lora,
     )
+    print(f"Lora Config: {config}")
     model = get_peft_model(model, config)
 
     if data_path.endswith(".json") or data_path.endswith(".jsonl"):
@@ -294,7 +301,7 @@ def train(
             max_grad_norm=0.3,
             num_train_epochs=num_epochs,
             learning_rate=learning_rate,
-            lr_scheduler_type="cosine",
+            lr_scheduler_type="constant" if qa_lora else "cosine",
             bf16=True,  # ensure training more stable
             logging_steps=1,
             optim="adamw_torch",
@@ -0,0 +1,29 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
+python ./alpaca_qlora_finetuning.py \
+    --base_model "meta-llama/Llama-2-7b-hf" \
+    --data_path "yahma/alpaca-cleaned" \
+    --output_dir "./bigdl-qlora-alpaca" \
+    --learning_rate 9e-5 \
+    --micro_batch_size 2 \
+    --batch_size 128 \
+    --lora_r 8 \
+    --lora_alpha 16 \
+    --lora_dropout 0.05 \
+    --val_set_size 2000 \
+    --qa_lora True
@@ -0,0 +1,34 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export OMP_NUM_THREADS=6 # adjust this to 1/4 of total physical cores
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+
+mpirun -n 2 \
+    python -u ./alpaca_qlora_finetuning.py \
+        --base_model "meta-llama/Llama-2-7b-hf" \
+        --data_path "yahma/alpaca-cleaned" \
+        --output_dir "./bigdl-qlora-alpaca" \
+        --learning_rate 9e-5 \
+        --micro_batch_size 2 \
+        --batch_size 128 \
+        --lora_r 8 \
+        --lora_alpha 16 \
+        --lora_dropout 0.05 \
+        --val_set_size 2000 \
+        --qa_lora True > training.log
@@ -0,0 +1,34 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export OMP_NUM_THREADS=28 # adjust this to 1/4 of total physical cores
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+
+mpirun -n 2 \
+    python -u ./alpaca_qlora_finetuning.py \
+        --base_model "meta-llama/Llama-2-7b-hf" \
+        --data_path "yahma/alpaca-cleaned" \
+        --output_dir "./bigdl-qlora-alpaca" \
+        --qa_lora True \
+        --learning_rate 9e-5 \
+        --micro_batch_size 8 \
+        --batch_size 128 \
+        --lora_r 8 \
+        --lora_alpha 16 \
+        --lora_dropout 0.05 \
+        --val_set_size 2000 > training.log
@@ -0,0 +1,31 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# You could also specify `--base_model` to the local path of the huggingface model checkpoint folder and `--data_path` to the local path of the dataset JSON file
+
+python ./alpaca_qlora_finetuning.py \
+    --base_model "meta-llama/Llama-2-7b-hf" \
+    --data_path "yahma/alpaca-cleaned" \
+    --output_dir "./bigdl-qlora-alpaca" \
+    --learning_rate 9e-5 \
+    --micro_batch_size 8 \
+    --batch_size 128 \
+    --gradient_checkpointing False \
+    --lora_r 8 \
+    --lora_alpha 16 \
+    --lora_dropout 0.05 \
+    --val_set_size 2000 \
+    --qa_lora True
@@ -31,11 +31,13 @@
 import os
 
 import torch
 import transformers
 from transformers import LlamaTokenizer  # noqa: F402
-from bigdl.llm.transformers.qlora import PeftModel
+from bigdl.llm.transformers.qlora import PeftModel, LoraConfig
 from bigdl.llm.transformers import AutoModelForCausalLM
+from bigdl.llm.transformers.low_bit_linear import get_block_size
 import argparse
+import tempfile
+import shutil
 
 if __name__ == "__main__":
@@ -51,6 +53,36 @@ if __name__ == "__main__":
     adapter_path = args.adapter_path
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
 
+    lora_config = LoraConfig.from_json_file(os.path.join(adapter_path, "adapter_config.json"))
+    qa_lora = lora_config.get("qa_lora", False)
+
+    temp_dir = None
+    if qa_lora:
+        # Convert the qa-lora adapter to the correct shapes
+        # The default 4-bit format for qa_lora is sym_int4
+        block_size = get_block_size("sym_int4")
+        temp_dir = tempfile.TemporaryDirectory()
+        tmpdirname = os.path.join(temp_dir.name, "adapter")
+        try:
+            shutil.copytree(adapter_path, tmpdirname)
+        except Exception as e:
+            print(f"Failed to copy adapter dir, error: {e}")
+        mid_lora_path = os.path.join(tmpdirname, "adapter_model.bin")
+
+        adapter_path = os.path.join(adapter_path, "adapter_model.bin")
+
+        lora = torch.load(adapter_path, map_location='cpu')
+        # Get lora_a names
+        tmp_keys = [key for key in lora.keys() if 'lora_A' in key]
+
+        for tmp_key in tmp_keys:
+            lora_a = lora[tmp_key] / block_size
+            lora[tmp_key] = torch.repeat_interleave(lora_a, block_size, dim=1)
+
+        torch.save(lora, mid_lora_path)
+        adapter_path = tmpdirname
+
+    try:
         base_model = AutoModelForCausalLM.from_pretrained(
             base_model,
             # load_in_low_bit="nf4", # should load the orignal model
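A note on the conversion above: QA-LoRA trains `lora_A` against a group-averaged input, so its weight has `in_features // block_size` columns. Because average pooling is linear, dividing each column by the group size and repeating it across the group yields a full-width matrix that behaves identically on the raw input, which is what the loop above builds before merging. A small standalone check of that equivalence (toy sizes, purely illustrative, not code from this commit):

```python
# Illustrative check that the expanded lora_A matches applying the trained
# lora_A to the group-averaged input.
import torch

r, in_features, block_size = 8, 64, 16           # toy sizes (assumption)
lora_a = torch.randn(r, in_features // block_size)
x = torch.randn(in_features)

# What QA-LoRA computes at training time: pool x over each quantization group.
pooled = x.view(-1, block_size).mean(dim=1)       # shape: [in_features // block_size]
out_trained = lora_a @ pooled

# What the export step builds: a full-width, rescaled lora_A.
lora_a_full = torch.repeat_interleave(lora_a / block_size, block_size, dim=1)
out_merged = lora_a_full @ x                      # shape: [r]

assert torch.allclose(out_trained, out_merged, atol=1e-5)
```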
@@ -79,3 +111,8 @@ if __name__ == "__main__":
 
         base_model.save_pretrained(args.output_path, state_dict=deloreanized_sd)
         tokenizer.save_pretrained(args.output_path)
+    except Exception as e:
+        print(f"Failed to merge the adapter, error: {e}.")
+    finally:
+        if qa_lora and temp_dir:
+            temp_dir.cleanup()
@@ -19,9 +19,9 @@ import os
 
 import transformers
 from transformers import LlamaTokenizer
-from peft import LoraConfig
 import intel_extension_for_pytorch as ipex
-from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
+from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
+    LoraConfig
 from bigdl.llm.transformers import AutoModelForCausalLM
 from datasets import load_dataset
 import argparse
@@ -109,8 +109,8 @@ def is_linear_module(module):
 
 
 def convert_gptq(module, awq=False):
-    from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size
-    Q4_1 = get_ggml_qk_size("asym_int4")
+    from bigdl.llm.transformers.low_bit_linear import get_block_size
+    Q4_1 = get_block_size("asym_int4")
 
     scales = module.scales
 
@@ -71,10 +71,14 @@ MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
 MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 
 
-def get_ggml_qk_size(qtype: str):
+def get_block_size(qtype: str):
     return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
 
 
+def get_qk_size(qtype: int):
+    return ggml.ggml_qk_size(qtype)
+
+
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        device=None, convert_shape_only=False):
     QK = ggml.ggml_qk_size(qtype)
@@ -124,7 +124,7 @@ class _BaseAutoModelClass:
         if load_in_4bit or load_in_low_bit:
 
             if config_dict.get("quantization_config", None) is not None:
-                from bigdl.llm.transformers.low_bit_linear import get_ggml_qk_size
+                from bigdl.llm.transformers.low_bit_linear import get_block_size
                 q_config = config_dict["quantization_config"]
                 if q_config["quant_method"] == "gptq":
                     invalidInputError(q_config["bits"] == 4,
@@ -136,10 +136,10 @@ class _BaseAutoModelClass:
                                       "You can only load gptq model as aysm_int4 low bit type.")
 
                     load_in_low_bit = "asym_int4"
-                    if int(q_config["group_size"]) % get_ggml_qk_size(load_in_low_bit) != 0:
+                    if int(q_config["group_size"]) % get_block_size(load_in_low_bit) != 0:
                         invalidInputError(False,
                                           (f"group_size must be divisible by "
-                                           f"{get_ggml_qk_size(load_in_low_bit)}."))
+                                           f"{get_block_size(load_in_low_bit)}."))
                     if user_quantization_config is not None:
                         invalidInputError(user_quantization_config.bits == 4,
                                           "Only 4-bit gptq is supported in bigdl-llm.")
@@ -166,10 +166,10 @@ class _BaseAutoModelClass:
 
                     load_in_low_bit = "asym_int4"
 
-                    if int(awq_config.group_size) % get_ggml_qk_size(load_in_low_bit) != 0:
+                    if int(awq_config.group_size) % get_block_size(load_in_low_bit) != 0:
                         invalidInputError(False,
                                           (f"group_size must be divisible by "
-                                           f"{get_ggml_qk_size(load_in_low_bit)}."))
+                                           f"{get_block_size(load_in_low_bit)}."))
 
                     kwargs["quantization_config"] = awq_config
 
@@ -49,7 +49,7 @@
 # limitations under the License.
 
 import torch
-from bigdl.llm.transformers.low_bit_linear import LowBitLinear
+from bigdl.llm.transformers.low_bit_linear import LowBitLinear, get_qk_size
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.utils import get_autocast_dtype
@@ -66,6 +66,7 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        qa_lora: bool = True,
         **kwargs,
     ):
         LowBitLinear.__init__(
@@ -76,7 +77,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             bias=kwargs.get("bias", True),
             conver_to_half=False,
         )
-        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
+
+        qk_size = get_qk_size(kwargs.get("qtype"))
+        lora_in_features = in_features // qk_size if qa_lora else in_features
+        LoraLayer.__init__(self, in_features=lora_in_features, out_features=out_features)
 
         # Freezing the pre-trained weight matrix
         self.weight.requires_grad = False
@@ -84,6 +88,10 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         init_lora_weights = kwargs.pop("init_lora_weights", True)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
         self.active_adapter = adapter_name
+        if qa_lora:
+            self.qa_pool = torch.nn.AvgPool1d(qk_size)
+        else:
+            self.qa_pool = torch.nn.Identity()
 
     def forward(self, x: torch.Tensor):
         autocast_dtype = get_autocast_dtype(x)
@@ -103,14 +111,16 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
             x = x.to(self.lora_A[self.active_adapter].weight.dtype)
             output = (
                 self.lora_B[self.active_adapter](
-                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    self.lora_A[self.active_adapter](
+                        self.lora_dropout[self.active_adapter](self.qa_pool(x)))
                 ).to(expected_dtype)
                 * self.scaling[self.active_adapter]
             )
         else:
             output = (
                 self.lora_B[self.active_adapter](
-                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+                    self.lora_A[self.active_adapter](
+                        self.lora_dropout[self.active_adapter](self.qa_pool(x)))
                 )
                 * self.scaling[self.active_adapter]
            )
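For context, the QA-LoRA change to the LoRA path is the `self.qa_pool(x)` call: inputs are average-pooled over each quantization group of size `qk_size` before `lora_A`, which is why the adapter is built with `in_features // qk_size` input features above. A standalone shape sketch (toy dimensions and plain `torch.nn` modules, assumed for illustration only):

```python
# Shape sketch of the QA-LoRA input pooling (toy sizes, standalone torch only).
import torch

batch, seq, in_features, qk_size, r = 2, 4, 64, 16, 8   # illustrative assumption
x = torch.randn(batch, seq, in_features)

qa_pool = torch.nn.AvgPool1d(qk_size)          # pools the last (feature) dimension
lora_A = torch.nn.Linear(in_features // qk_size, r, bias=False)
lora_B = torch.nn.Linear(r, 128, bias=False)   # 128 stands in for out_features

pooled = qa_pool(x)                            # [2, 4, 64] -> [2, 4, 4]
update = lora_B(lora_A(pooled))                # [2, 4, 128], added to the base output
print(pooled.shape, update.shape)
```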
@@ -126,6 +136,7 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
     low_bit_kwargs.update(
         {
             "qtype": target.qtype,
+            "qa_lora": lora_config.qa_lora,
         }
     )
     new_module = LoraLowBitLinear(adapter_name,
@@ -140,6 +151,14 @@ def _create_new_module(create_new_module_func, lora_config, adapter_name, target
 
 
 from peft.tuners.lora import LoraModel
+from peft.tuners.lora import LoraConfig as LoraConfigBase
+from dataclasses import dataclass, field
+
+
+@dataclass
+class LoraConfig(LoraConfigBase):
+
+    qa_lora: bool = field(default=False, metadata={"help": "enable qa-lora"})
 
 
 def get_peft_model(*args, **kwargs):
@@ -357,6 +376,10 @@ def _setup_devices(self) -> "torch.device":
         torch.cuda.set_device(device)
     return device
 
+from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
+
+PEFT_TYPE_TO_CONFIG_MAPPING["lora"] = LoraConfig
+
 # workaround a IPEX bug that prevents resume training in bf16
 from accelerate import Accelerator
 Accelerator._prepare_ipex = patch_prepare_ipex