LLM: optimize QLoRA by updating lora convert logic (#9372)

* update the convert logic of QLoRA

* update

* refactor and further improve performance

* fix style

* address code review comments
Ruonan Wang, 2023-11-08 17:46:49 +08:00, committed by GitHub
parent 54d95e4907
commit bfca76dfa7
3 changed files with 21 additions and 5 deletions

bigdl/llm/transformers/low_bit_linear.py

@@ -336,8 +336,6 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, A, weight, input_seq_size):
-        if torch.xpu.is_autocast_xpu_enabled():
-            A = A.to(torch.xpu.get_autocast_xpu_dtype())
         ctx.is_empty = False
         import linear_q4_0
         result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
@@ -448,8 +446,6 @@ class LowBitLinear(nn.Linear):
                                              input_seq_size)
             result = result.to(x.dtype)
         else:
-            if torch.xpu.is_autocast_xpu_enabled():
-                x_2d = x_2d.to(torch.xpu.get_autocast_xpu_dtype())
             result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                              input_seq_size)
         new_shape = x_shape[:-1] + (self.out_len,)
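
For context: the two removals above take the per-call XPU autocast check out of the low-bit matmul hot path. The autocast dtype is now resolved once by the caller (see the LoraLowBitLinear hunk below), so forward_new receives an already-cast tensor. A minimal before/after sketch of the idea, with low_bit_matmul as a hypothetical stand-in for linear_q4_0.forward_new:

    import torch

    def low_bit_matmul(x, weight):
        # hypothetical stand-in for linear_q4_0.forward_new
        return x @ weight

    # before: every low-bit forward re-queried autocast state itself
    # (the torch.xpu autocast APIs exist only on XPU-enabled builds)
    def forward_before(x, weight):
        if torch.xpu.is_autocast_xpu_enabled():
            x = x.to(torch.xpu.get_autocast_xpu_dtype())
        return low_bit_matmul(x, weight)

    # after: the kernel trusts its input dtype; the caller casts once
    def forward_after(x, weight):
        return low_bit_matmul(x, weight)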

bigdl/llm/transformers/qlora.py

@@ -52,6 +52,7 @@ import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import get_autocast_dtype
 import functools
@@ -85,13 +86,16 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         self.active_adapter = adapter_name
 
     def forward(self, x: torch.Tensor):
+        autocast_dtype = get_autocast_dtype(x)
+        if autocast_dtype is not None:
+            x = x.to(autocast_dtype)
         result = super().forward(x)
 
         if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if not torch.is_autocast_enabled():
+            if autocast_dtype is None:
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (
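
The net effect in LoraLowBitLinear.forward: the autocast dtype is resolved once per call and reused twice, first to pre-cast x before the low-bit base layer runs, then to decide whether the LoRA output must be cast back to the result dtype (the old code queried torch.is_autocast_enabled() again for that). A simplified sketch of the control flow, with base_forward and lora_delta as hypothetical stand-ins for the LowBitLinear matmul and the LoRA A/B product:

    import torch
    from bigdl.llm.transformers.utils import get_autocast_dtype

    def lora_forward(x, base_forward, lora_delta):
        # resolve the autocast dtype once instead of re-querying
        # torch autocast state at every step
        autocast_dtype = get_autocast_dtype(x)
        if autocast_dtype is not None:
            x = x.to(autocast_dtype)
        result = base_forward(x)
        if autocast_dtype is None:
            # outside autocast: cast the LoRA delta back explicitly
            expected_dtype = result.dtype
            result = result + lora_delta(x).to(expected_dtype)
        else:
            # under autocast, dtype propagation is handled for us
            result = result + lora_delta(x)
        return result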

bigdl/llm/transformers/utils.py

@@ -133,3 +133,19 @@ def fix_key(key):
     if "gamma" in key:
         return key.replace("gamma", "weight")
     return key
+
+
+def get_autocast_dtype(x):
+    if x.device.type == "xpu":
+        if torch.xpu.is_autocast_xpu_enabled():
+            return torch.xpu.get_autocast_xpu_dtype()
+        else:
+            return None
+    elif x.device.type == "cpu":
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_cpu_dtype()
+        else:
+            return None
+    else:
+        invalidInputError(False,
+                          f"Device {x.device} is not supported.")