LLM: optimize QLoRA by updating LoRA convert logic (#9372)

* Update the QLoRA convert logic
* Refactor and further improve performance
* Fix style
* Address code review comments
parent 54d95e4907
commit bfca76dfa7
3 changed files with 21 additions and 5 deletions
@@ -336,8 +336,6 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, A, weight, input_seq_size):
-        if torch.xpu.is_autocast_xpu_enabled():
-            A = A.to(torch.xpu.get_autocast_xpu_dtype())
         ctx.is_empty = False
         import linear_q4_0
         result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
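For context, MatMulLowBit is a torch.autograd.Function whose forward used to cast its activation to the XPU autocast dtype on every call; after this change the input is expected to arrive already cast. Below is a minimal runnable sketch of the same Function/custom_fwd pattern. It is illustrative only: SketchMatMul is a made-up name, and a plain matmul stands in for BigDL's native linear_q4_0 kernel.

import torch
from torch.cuda.amp import custom_fwd, custom_bwd  # newer torch prefers torch.amp.custom_fwd(device_type=...)

class SketchMatMul(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, A, weight):
        # A plain matmul stands in for linear_q4_0.forward_new; note the
        # forward no longer casts A itself, callers cast before .apply().
        ctx.save_for_backward(A, weight)
        return A @ weight.t()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        A, weight = ctx.saved_tensors
        # Gradients for (A, weight), matching forward's input order.
        return grad_output @ weight, grad_output.t() @ A

A = torch.randn(2, 4, requires_grad=True)
W = torch.randn(3, 4)
SketchMatMul.apply(A, W).sum().backward()  # grad flows back to A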
@@ -448,8 +446,6 @@ class LowBitLinear(nn.Linear):
                                              input_seq_size)
            result = result.to(x.dtype)
        else:
-            if torch.xpu.is_autocast_xpu_enabled():
-                x_2d = x_2d.to(torch.xpu.get_autocast_xpu_dtype())
            result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                             input_seq_size)
        new_shape = x_shape[:-1] + (self.out_len,)
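Both deletions above remove a per-call query of the XPU autocast state from the hot matmul path in bigdl.llm.transformers.low_bit_linear; the cast decision now happens once, in the LoRA wrapper shown below. The following sketch shows the hoisting pattern using the CPU/CUDA autocast APIs, since torch.xpu requires intel-extension-for-pytorch. The names resolve_autocast_dtype, low_bit_matmul, and wrapper_forward are illustrative, not BigDL APIs.

import torch

def resolve_autocast_dtype(x: torch.Tensor):
    # Illustrative CPU/CUDA analogue of the XPU helper this commit adds.
    if x.device.type == "cpu" and torch.is_autocast_cpu_enabled():
        return torch.get_autocast_cpu_dtype()
    if x.device.type == "cuda" and torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    return None

def low_bit_matmul(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Hot path: no autocast queries in here any more.
    return a @ w.t()

def wrapper_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    dtype = resolve_autocast_dtype(x)  # resolved once per forward call
    if dtype is not None:
        x = x.to(dtype)                # mirrors the cast moved out of MatMulLowBit
    return low_bit_matmul(x, w)

print(wrapper_forward(torch.randn(2, 8), torch.randn(4, 8)).shape)  # torch.Size([2, 4])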
@@ -52,6 +52,7 @@ import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import get_autocast_dtype
 import functools
 
@@ -85,13 +86,16 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         self.active_adapter = adapter_name
 
     def forward(self, x: torch.Tensor):
+        autocast_dtype = get_autocast_dtype(x)
+        if autocast_dtype is not None:
+            x = x.to(autocast_dtype)
         result = super().forward(x)
 
         if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if not torch.is_autocast_enabled():
+            if autocast_dtype is None:
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (
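The wrapper now resolves the autocast dtype once via the new get_autocast_dtype helper and reuses the answer twice: to pre-cast the input before the low-bit base layer, and to decide (autocast_dtype is None) whether the LoRA branch must manage dtypes manually, replacing the torch.is_autocast_enabled() check (which does not track XPU autocast state). A toy CPU-only sketch of this control flow follows; ToyLoraLinear is a made-up class, whereas the real code subclasses LowBitLinear and peft's LoraLayer.

import torch
import torch.nn as nn

class ToyLoraLinear(nn.Module):
    def __init__(self, in_f, out_f, r=8):
        super().__init__()
        self.base = nn.Linear(in_f, out_f)
        self.lora_A = nn.Linear(in_f, r, bias=False)
        self.lora_B = nn.Linear(r, out_f, bias=False)

    def forward(self, x):
        # Resolve the autocast dtype once, instead of re-querying torch's
        # autocast state in every branch below.
        autocast_dtype = (torch.get_autocast_cpu_dtype()
                          if torch.is_autocast_cpu_enabled() else None)
        if autocast_dtype is not None:
            x = x.to(autocast_dtype)
        result = self.base(x)
        if autocast_dtype is None:
            # Outside autocast: cast manually so the A/B matmuls match weights.
            expected_dtype = result.dtype
            x = x.to(self.lora_A.weight.dtype)
            result = result + self.lora_B(self.lora_A(x)).to(expected_dtype)
        else:
            result = result + self.lora_B(self.lora_A(x))
        return result

print(ToyLoraLinear(16, 16)(torch.randn(2, 16)).shape)  # torch.Size([2, 16])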
@@ -133,3 +133,19 @@ def fix_key(key):
     if "gamma" in key:
         return key.replace("gamma", "weight")
     return key
+
+
+def get_autocast_dtype(x):
+    if x.device.type == "xpu":
+        if torch.xpu.is_autocast_xpu_enabled():
+            return torch.xpu.get_autocast_xpu_dtype()
+        else:
+            return None
+    elif x.device.type == "cpu":
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_cpu_dtype()
+        else:
+            return None
+    else:
+        invalidInputError(False,
+                          f"Device {x.device} is not supported.")
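A usage sketch for the new helper's CPU branch (assumes bigdl-llm is installed; exercising the xpu branch additionally requires intel-extension-for-pytorch):

import torch
from bigdl.llm.transformers.utils import get_autocast_dtype

x = torch.randn(4, 8)         # plain CPU tensor
print(get_autocast_dtype(x))  # None: no autocast region is active

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # The cpu branch consults torch.is_autocast_enabled() and, when it
    # reports True, returns torch.get_autocast_cpu_dtype().
    print(get_autocast_dtype(x))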