Experiment XPU QLora Finetuning (#8937)
* Support xpu finetuning
* support xpu finetuning
* fix style
* fix style
* fix style
* refine example
* add readme
* refine readme
* refine api
* fix fp16
* fix example
* refactor
* fix style
* fix compute type
* add qlora
* refine training args
* fix example
* fix style
* fast path for inference
* address comments
* refine readme
* revert lint
parent 51518e029d
commit c88f6ec457

6 changed files with 373 additions and 8 deletions
python/llm/example/gpu/qlora_finetuning/README.md (new file, 50 lines)

@@ -0,0 +1,50 @@
# Q-Lora (experimental support)

This example demonstrates how to finetune a llama2-7b model with BigDL-LLM 4bit optimizations on [Intel GPUs](../README.md).

## 0. Requirements
To run this example with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.

## Example: Finetune llama2-7b using QLoRA

This example is ported from [bnb-4bit-training](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing).

### 1. Install

```bash
conda create -n llm python=3.9
conda activate llm
# the command below installs intel_extension_for_pytorch==2.0.110+xpu by default
# you can install a specific ipex/torch version for your needs
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install git+https://github.com/huggingface/transformers.git@95fe0f5
pip install peft==0.5.0
```

### 2. Configure OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```

### 3. Run

```bash
python ./qlora_finetuning.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH
```

### Sample Output
```log
{'loss': 1.6134, 'learning_rate': 0.0002, 'epoch': 0.03}
{'loss': 1.3038, 'learning_rate': 0.00017777777777777779, 'epoch': 0.06}
{'loss': 1.2634, 'learning_rate': 0.00015555555555555556, 'epoch': 0.1}
{'loss': 1.2389, 'learning_rate': 0.00013333333333333334, 'epoch': 0.13}
{'loss': 1.0399, 'learning_rate': 0.00011111111111111112, 'epoch': 0.16}
{'loss': 1.0406, 'learning_rate': 8.888888888888889e-05, 'epoch': 0.19}
{'loss': 1.3114, 'learning_rate': 6.666666666666667e-05, 'epoch': 0.22}
{'loss': 0.9876, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.26}
{'loss': 1.1406, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.29}
{'loss': 1.1728, 'learning_rate': 0.0, 'epoch': 0.32}
{'train_runtime': 225.8005, 'train_samples_per_second': 3.543, 'train_steps_per_second': 0.886, 'train_loss': 1.211241865158081, 'epoch': 0.32}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [03:45<00:00,  1.13s/it]
TrainOutput(global_step=200, training_loss=1.211241865158081, metrics={'train_runtime': 225.8005, 'train_samples_per_second': 3.543, 'train_steps_per_second': 0.886, 'train_loss': 1.211241865158081, 'epoch': 0.32})
```
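As a follow-up to the README above: once training finishes, the finetuned LoRA adapter can be reloaded for inference through the `PeftModel.from_pretrained` wrapper that this commit adds in `bigdl/llm/transformers/qlora.py`. A minimal sketch, assuming the adapter was saved to `./qlora-adapter` (e.g. via `model.save_pretrained("./qlora-adapter")` after `trainer.train()`; the example script itself does not save it):

```python
import torch
from transformers import LlamaTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm.transformers.qlora import PeftModel

# Load the base model in BigDL-LLM 4-bit format, as in the training script.
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                                  load_in_4bit=True,
                                                  optimize_model=False,
                                                  modules_to_not_convert=["lm_head"])
base_model = base_model.to('xpu')

# Attach the finetuned LoRA adapter (the path is an assumption, see above).
model = PeftModel.from_pretrained(base_model, "./qlora-adapter")
model.config.use_cache = True  # re-enable the cache that training disabled

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Two things are infinite: ", return_tensors="pt").to('xpu')
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```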
python/llm/example/gpu/qlora_finetuning/qlora_finetuning.py (new file, 84 lines)

@@ -0,0 +1,84 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import os
os.environ["ACCELERATE_USE_IPEX"] = "true"
os.environ["ACCELERATE_USE_XPU"] = "true"

import transformers
from transformers import LlamaTokenizer

from peft import LoraConfig
import intel_extension_for_pytorch as ipex
from bigdl.llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from bigdl.llm.transformers import AutoModelForCausalLM
from datasets import load_dataset
import argparse

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--dataset', type=str, default="Abirate/english_quotes")

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    dataset_path = args.dataset
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

    data = load_dataset(dataset_path)
    data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=False,
                                                 modules_to_not_convert=["lm_head"])
    model = model.to('xpu')
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, config)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
    trainer = transformers.Trainer(
        model=model,
        train_dataset=data["train"],
        args=transformers.TrainingArguments(
            per_device_train_batch_size=4,
            gradient_accumulation_steps=1,
            warmup_steps=20,
            max_steps=200,
            learning_rate=2e-4,
            fp16=False,  # fp16 is not supported yet
            logging_steps=20,
            output_dir="outputs",
            optim="adamw_hf",  # paged_adamw_8bit is not supported yet
            # gradient_checkpointing=True,  # can further reduce memory but slower
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    result = trainer.train()
    print(result)
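One optional check, not part of the script above: after `get_peft_model(model, config)`, peft's built-in helper reports how much of the model is actually trainable, which should be only the LoRA matrices since `prepare_model_for_kbit_training` freezes the base weights.

```python
# Optional sanity check after model = get_peft_model(model, config):
# only the LoRA A/B matrices should require gradients.
model.print_trainable_parameters()
# prints something like: trainable params: ... || all params: ... || trainable%: ...
```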
@@ -99,8 +99,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
-                         convert_shape_only=False, device="cpu"):
-    modules_to_not_convert = []  # ["lm_head"]
+                         convert_shape_only=False, device="cpu",
+                         modules_to_not_convert=None):
+    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
         None, convert_shape_only,
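The new `modules_to_not_convert` argument lets callers keep selected modules, here `lm_head`, out of the 4-bit conversion. It is threaded down from the user-facing loader; a minimal sketch of the intended call, mirroring the example script above:

```python
from bigdl.llm.transformers import AutoModelForCausalLM

# Keep lm_head unquantized while the rest of the model is loaded in 4 bit;
# from_pretrained pops this kwarg and forwards it to ggml_convert_low_bit().
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # or a local checkpoint folder
    load_in_4bit=True,
    optimize_model=False,
    modules_to_not_convert=["lm_head"],
)
```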
@@ -284,8 +284,38 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
     return result_t
 
 
+class MatMulLowBit(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, A, weight):
+        ctx.is_empty = False
+        import linear_q4_0
+        result = linear_q4_0.forward_new(A, weight.data, weight.qtype)
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (A, weight)
+        else:
+            ctx.tensors = (None, None)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        import linear_q4_0
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+        req_gradA, _ = ctx.needs_input_grad
+        A, weight = ctx.tensors
+        grad_A, grad_weight = None, None
+        if req_gradA:
+            dequant_weight = linear_q4_0.dequant(A, weight.data, weight.qtype)
+            grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))
+
+        return grad_A, grad_weight
+
+
 class LowBitLinear(nn.Linear):
-    def __init__(self, input_features, output_features, qtype, bias=True):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 conver_to_half=True):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,

@@ -295,6 +325,7 @@ class LowBitLinear(nn.Linear):
         self.weight_shape = (self.out_len, self.in_len)
         self.weight_length = self.out_len * self.in_len
         self.qtype = qtype
+        self.conver_to_half = conver_to_half
 
     def forward(self, x: torch.Tensor):
         if self.bias is not None and self.bias.dtype != x.dtype:

@@ -317,10 +348,14 @@ class LowBitLinear(nn.Linear):
             if x_2d.is_contiguous() is False:
                 x_2d = x_2d.contiguous()
             # current workaround to reduce first token latency of fp32 input
-            if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
+            # sometimes fp16 cause nan and training instability
+            # disable the conversion when training
+            if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                 x_2d = x_2d.half()
-            # input format of linear_q4.forward is 1: input, 2: weight
-            result = linear_q4_0.forward_new(x_2d, x0, self.qtype)
+            if self.training and x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight)
+            else:
+                result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.bias is not None:
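The hunks above add a training path to `LowBitLinear.forward`: when gradients are required, the matmul goes through `MatMulLowBit`, a custom `torch.autograd.Function` whose backward dequantizes the frozen 4-bit weight and returns a gradient only for the activations. The sketch below illustrates the same pattern with plain float tensors standing in for the `linear_q4_0` XPU kernels; it is an illustration of the idea, not the shipped kernel.

```python
import torch

class MatMulFrozenWeight(torch.autograd.Function):
    """Illustrative stand-in for MatMulLowBit: the weight is treated as frozen,
    so backward only returns a gradient for the activation input A."""

    @staticmethod
    def forward(ctx, A, weight):
        # In the real kernel this is linear_q4_0.forward_new(A, weight.data, qtype).
        ctx.save_for_backward(weight)
        return A @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # In the real kernel the 4-bit weight is dequantized here before the matmul.
        grad_A = grad_output @ weight
        # No gradient for the frozen (quantized) weight.
        return grad_A, None

# Usage: gradients flow to the activations only, as in QLoRA training.
A = torch.randn(4, 16, requires_grad=True)
W = torch.randn(8, 16)            # frozen "quantized" weight stand-in
out = MatMulFrozenWeight.apply(A, W)
out.sum().backward()
print(A.grad.shape)               # torch.Size([4, 16]); W receives no grad
```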
@@ -108,6 +108,7 @@ class _BaseAutoModelClass:
         # In case it needs a second try,
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
+        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         try:

@@ -119,7 +120,8 @@ class _BaseAutoModelClass:
             model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
             model.config.update({"bigdl_lcmu_enabled": False})
         model = model.to("cpu")
-        model = ggml_convert_low_bit(model, qtype, optimize_model)
+        model = ggml_convert_low_bit(model, qtype, optimize_model,
+                                     modules_to_not_convert=modules_to_not_convert)
         model.config.update({"bigdl_transformers_low_bit": q_k})
         model.config.update({"tie_word_embeddings": False})
 

@@ -155,6 +157,7 @@ class _BaseAutoModelClass:
         import copy
         import os
 
+        modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
         kwargs_orig = copy.deepcopy(kwargs)

@@ -264,7 +267,8 @@ class _BaseAutoModelClass:
 
         # Loading args may differ based on their usage
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
-        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device)
+        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
+                                     modules_to_not_convert=modules_to_not_convert)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
python/llm/src/bigdl/llm/transformers/qlora.py (new file, 191 lines)

@@ -0,0 +1,191 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/peft/blob/v0.5.0/src/peft/tuners/lora.py
#
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from bigdl.llm.transformers.low_bit_linear import LowBitLinear
from peft.tuners.lora import LoraLayer
from bigdl.llm.utils.common import invalidInputError


class LoraLowBitLinear(LowBitLinear, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        adapter_name,
        in_features,
        out_features,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        LowBitLinear.__init__(
            self,
            in_features,
            out_features,
            qtype=kwargs.get("qtype"),
            bias=kwargs.get("bias", True),
            conver_to_half=False,
        )
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)

        # Freezing the pre-trained weight matrix
        self.weight.requires_grad = False

        init_lora_weights = kwargs.pop("init_lora_weights", True)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def forward(self, x: torch.Tensor):
        result = super().forward(x)

        if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
            return result
        elif self.r[self.active_adapter] > 0:
            result = result.clone()
            if not torch.is_autocast_enabled():
                expected_dtype = result.dtype
                x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                output = (
                    self.lora_B[self.active_adapter](
                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
                    ).to(expected_dtype)
                    * self.scaling[self.active_adapter]
                )
            else:
                output = (
                    self.lora_B[self.active_adapter](
                        self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
                    )
                    * self.scaling[self.active_adapter]
                )
            result += output
        return result


@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):

    bias = kwargs.pop("bias", False)

    if isinstance(target, LowBitLinear):
        low_bit_kwargs = kwargs.copy()
        low_bit_kwargs.update(
            {
                "qtype": target.qtype,
            }
        )
        new_module = LoraLowBitLinear(adapter_name,
                                      target.in_features,
                                      target.out_features,
                                      bias=bias,
                                      **low_bit_kwargs)
    else:
        invalidInputError(False,
                          f"Target module {target} is not supported. "
                          f"Currently, only `LowBitLinear` are supported.")

    return new_module


from peft.tuners.lora import LoraModel


def get_peft_model(*args, **kwargs):
    old_create_new_module = LoraModel._create_new_module
    LoraModel._create_new_module = _create_new_module
    try:
        from peft import get_peft_model as get_peft_model_original
        model = get_peft_model_original(*args, **kwargs)
    finally:
        LoraModel._create_new_module = old_create_new_module

    return model


def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    r"""
    This method wraps the entire protocol for preparing a model before running a training.
    This includes:
        1- Cast the layernorm in fp32
        2- making output embedding layer require grads
        3- Add the upcasting of the lm head to fp32

    Args:
        model, (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
    """

    is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

    if not is_gptq_quantized:
        # cast all non INT8 parameters to fp32
        for param in model.parameters():
            if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()

    return model


class PeftModel:

    @staticmethod
    def from_pretrained(*args,
                        **kwargs):
        old_create_new_module = LoraModel._create_new_module
        LoraModel._create_new_module = _create_new_module
        from peft import PeftModel
        try:
            model = PeftModel.from_pretrained(*args, **kwargs)
        finally:
            LoraModel._create_new_module = old_create_new_module

        return model