LLM: optimize QLoRA by updating lora convert logic (#9372)
* update convert logic of qlora
* update
* refactor and further improve performance
* fix style
* meet code review
Parent: 54d95e4907
Commit: bfca76dfa7
3 changed files with 21 additions and 5 deletions

In short: the per-call autocast dtype checks are removed from the low-bit forward paths, and the autocast dtype is instead resolved once per LoraLowBitLinear.forward through a new get_autocast_dtype helper.
In bigdl/llm/transformers/low_bit_linear.py (module path inferred from the import shown in the next file):

@@ -336,8 +336,6 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, A, weight, input_seq_size):
-        if torch.xpu.is_autocast_xpu_enabled():
-            A = A.to(torch.xpu.get_autocast_xpu_dtype())
         ctx.is_empty = False
         import linear_q4_0
         result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
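For orientation, a minimal self-contained sketch of the shape MatMulLowBit takes after this commit: the custom autograd Function no longer inspects autocast state in forward. Here custom_fwd/custom_bwd come from torch.cuda.amp (the decorator in the real file may be an XPU-specific variant), and a plain matmul stands in for the linear_q4_0.forward_new kernel.

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class SketchMatMul(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, A, weight):
        # No autocast check here any more; A arrives already cast by the caller.
        ctx.save_for_backward(A, weight)
        return A @ weight.t()          # stand-in for linear_q4_0.forward_new

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        A, weight = ctx.saved_tensors
        return grad_output @ weight, grad_output.t() @ A

A = torch.randn(2, 3, requires_grad=True)
W = torch.randn(4, 3, requires_grad=True)
SketchMatMul.apply(A, W).sum().backward()   # grads flow through both inputs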
@@ -448,8 +446,6 @@ class LowBitLinear(nn.Linear):
                                               input_seq_size)
             result = result.to(x.dtype)
         else:
-            if torch.xpu.is_autocast_xpu_enabled():
-                x_2d = x_2d.to(torch.xpu.get_autocast_xpu_dtype())
             result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                              input_seq_size)
         new_shape = x_shape[:-1] + (self.out_len,)
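The two hunks above are the performance half of the change: previously every low-bit matmul re-checked XPU autocast state and cast its activation; now the kernels trust the dtype of their input. A hedged CPU-only illustration of the before/after pattern (plain torch calls stand in for the XPU-specific torch.xpu.* queries, which require an XPU build):

import torch

def matmul_old(A, W):
    # Pre-commit shape: autocast check repeated inside every kernel wrapper.
    if torch.is_autocast_cpu_enabled():           # stand-in for torch.xpu.is_autocast_xpu_enabled()
        A = A.to(torch.get_autocast_cpu_dtype())  # stand-in for torch.xpu.get_autocast_xpu_dtype()
    return A @ W.t()

def matmul_new(A, W):
    # Post-commit shape: no dtype logic; the caller casts A once up front.
    return A @ W.t()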
In the module defining LoraLowBitLinear (file name not shown in this view):

@@ -52,6 +52,7 @@ import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import get_autocast_dtype
 import functools
 
 
@@ -85,13 +86,16 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         self.active_adapter = adapter_name
 
     def forward(self, x: torch.Tensor):
+        autocast_dtype = get_autocast_dtype(x)
+        if autocast_dtype is not None:
+            x = x.to(autocast_dtype)
         result = super().forward(x)
 
         if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if not torch.is_autocast_enabled():
+            if autocast_dtype is None:
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (
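A runnable sketch of the control flow the patched forward implements, with a plain nn.Linear standing in for the low-bit base layer and simplified lora_A/lora_B modules in place of peft's per-adapter dicts (the names follow peft's LoraLayer convention; everything else here is illustrative):

import torch
import torch.nn as nn

class SketchLoraLinear(nn.Module):
    def __init__(self, in_features, out_features, r=8):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)      # stand-in for LowBitLinear
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.scaling = 1.0

    def forward(self, x):
        # As in the commit: resolve the autocast dtype once and cast x up front.
        autocast_dtype = None
        if x.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            autocast_dtype = torch.get_autocast_cpu_dtype()
        if autocast_dtype is not None:
            x = x.to(autocast_dtype)
        result = self.base(x)
        if autocast_dtype is None:
            # No autocast: run the adapter in its own dtype, cast the delta back.
            expected_dtype = result.dtype
            x = x.to(self.lora_A.weight.dtype)
            result = result + self.lora_B(self.lora_A(x)).to(expected_dtype) * self.scaling
        else:
            result = result + self.lora_B(self.lora_A(x)) * self.scaling
        return result

layer = SketchLoraLinear(16, 16)
print(layer(torch.randn(2, 16)).shape)   # torch.Size([2, 16])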
In bigdl/llm/transformers/utils.py (module path taken from the import added above):

@@ -133,3 +133,19 @@ def fix_key(key):
     if "gamma" in key:
         return key.replace("gamma", "weight")
     return key
+
+
+def get_autocast_dtype(x):
+    if x.device.type == "xpu":
+        if torch.xpu.is_autocast_xpu_enabled():
+            return torch.xpu.get_autocast_xpu_dtype()
+        else:
+            return None
+    elif x.device.type == "cpu":
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_cpu_dtype()
+        else:
+            return None
+    else:
+        invalidInputError(False,
+                          f"Device {x.device} is not supported.")
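A CPU-only demonstration of the contract the new helper provides: the active autocast dtype inside an autocast region, None outside it. The committed helper queries torch.is_autocast_enabled() on the CPU branch; this sketch uses the explicitly CPU-scoped variants so it is unambiguous and runnable on a build without XPU support.

import torch

def autocast_dtype_cpu(x: torch.Tensor):
    # CPU-only restatement of get_autocast_dtype for demonstration purposes.
    if x.device.type == "cpu" and torch.is_autocast_cpu_enabled():
        return torch.get_autocast_cpu_dtype()
    return None

x = torch.randn(2, 4)
print(autocast_dtype_cpu(x))        # None: no autocast region active
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    print(autocast_dtype_cpu(x))    # torch.bfloat16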