refactor xpu linear forward (#12768)

parent 413d6c2b66
commit 0237ffb302

3 changed files with 32 additions and 82 deletions
@@ -500,16 +500,16 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx, A, weight, input_seq_size):
+    def forward(ctx, A, weight, output_size):
         ctx.is_empty = False
         import xe_linear
         if weight.qtype == NF4:
             result = xe_linear.forward_new(A,
                                            weight.data.view(torch.uint8),
                                            weight.qtype,
-                                           input_seq_size)
+                                           output_size)
         else:
-            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, output_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
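Note: the rename from input_seq_size to output_size is what enables the fake-kernel registration in the third changed file below. Under fake-tensor tracing the kernel never runs, so the output shape must be derivable from an explicit argument. A minimal sketch of that pattern, with hypothetical names (demo::low_bit_mm is illustrative, not the real xe_linear binding):

    import torch

    @torch.library.custom_op("demo::low_bit_mm", mutates_args=())
    def low_bit_mm(x: torch.Tensor, packed_w: torch.Tensor, qtype: int,
                   output_size: int) -> torch.Tensor:
        # Stand-in for the real XPU kernel (dequantize packed_w, then matmul).
        # Only the contract matters here: [M, K] input -> [M, output_size] output.
        return x.new_zeros(x.size(0), output_size)

    @low_bit_mm.register_fake
    def _(x, packed_w, qtype, output_size):
        # Shape inference must not touch real data; it works only because
        # output_size is an explicit argument rather than derived in the kernel.
        return torch.empty(x.size(0), output_size, dtype=x.dtype, device=x.device)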
@@ -627,89 +627,50 @@ class LowBitLinear(nn.Linear):
         if self.optimize_lm_head:
             x = reshape_lm_head_input(x)

-        # [batch, input_num, in_len]
-        # input_num == token num for Transformer
-        x_shape = x.shape
-        # Output shape, e.g., [batch, input_num, out_len]
-        new_shape = x_shape[:-1] + (self.out_len,)
+        # [batch, seq_len, in_len] -> [batch, seq_len, out_len]
+        new_shape = x.shape[:-1] + (self.out_len,)

         # Activation is empty tensor, e.g., [1, 0, 4096]
-        if 0 in x_shape:
+        if 0 in x.shape:
             # return empty tensor with output shape, x.dtype and x.device
             return torch.empty(new_shape, dtype=x.dtype, device=x.device)

-        x_2d = x.contiguous().view(-1, x_shape[-1])
-
         if self.act_order:
-            x_2d = x_2d[:, self.g_idx_map]
-        # x0 for weight
-        x0 = self.weight.data
+            x = x[..., self.g_idx_map]
+
+        x_2d = x.contiguous().view(-1, x.shape[-1])

-        if x0.device.type == "xpu":
-            # GPU logic
-            try:
-                import xe_linear
-                from ipex_llm.transformers.models.utils import use_xmx
-            except ModuleNotFoundError:
-                invalidInputError(False,
-                                  "Please `pip install bigdl_core_xe` first.")
-
-            if x_2d.is_contiguous() is False:
-                x_2d = x_2d.contiguous()
-
-            if len(x_shape) == 3:
-                input_seq_size = x_shape[1]
-            elif len(x_shape) < 3:
-                input_seq_size = 1
-
-            if is_training:
-                # training path
-                if x_2d.requires_grad:
-                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+        if self.weight.device.type == "xpu":
+            if is_training and x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
-                if self.weight.qtype == NF4:
-                    result = xe_linear.forward_new(x_2d,
-                                                   self.weight.data.view(torch.uint8),
-                                                   self.weight.qtype,
-                                                   input_seq_size)
-                else:
-                    result = xe_linear.forward_new(x_2d,
-                                                   self.weight.data,
-                                                   self.weight.qtype,
-                                                   input_seq_size)
-            else:
-                # inference path
-                # current workaround to reduce first token latency of fp32 input
-                # sometimes fp16 cause nan and training instability
-                # disable the conversion when training
                 # TODO: may modify the input length condition for empty cache.
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
                 if do_empty_cache:
                     torch.xpu.empty_cache()

+                if self.qtype == NF4:
+                    w = self.weight.data.view(torch.uint8)
+                else:
+                    w = self.weight.data
+
                 if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
                     import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
-                elif not is_training and self.conver_to_half \
-                        and x_2d.shape[0] > 1 and x_2d.dtype == torch.float:
+                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
+                elif (
+                    self.conver_to_half
+                    and x_2d.shape[0] > 1
+                    and x_2d.dtype == torch.float32
+                    and not use_xmx(x_2d, self.weight.qtype)
+                ):
                     import xe_linear
                     x_2d = x_2d.half()
-                    result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                   self.weight.qtype, input_seq_size)
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
                     result = result.to(x.dtype)
                 else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d, self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype, input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                       self.weight.qtype, input_seq_size)
+                    import xe_linear
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)

                 if do_empty_cache:
                     torch.xpu.empty_cache()

             result = result.view(new_shape)

             if self.mp_group is not None:
                 if get_use_vllm():
                     result = self.mp_group.all_reduce(result)
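Note on the hoisted w: for NF4 the packed weight buffer is now reinterpreted as raw bytes once, instead of separately in every branch. Tensor.view(dtype) reinterprets the same storage without copying when the element sizes match. A CPU-only illustration (shapes are made up):

    import torch

    packed = torch.randint(-128, 128, (4, 8), dtype=torch.int8)  # pretend-packed blocks
    raw = packed.view(torch.uint8)                # reinterpret in place, no copy
    assert raw.data_ptr() == packed.data_ptr()    # same storage
    assert raw.shape == packed.shape              # same itemsize -> same shape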
@@ -718,6 +679,7 @@ class LowBitLinear(nn.Linear):
                     dist.inference_all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")

             if self.bias is not None:
                 result += self.bias
         else:
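For context on the mp_group branch: when the weight is sharded across a tensor-parallel group, each rank computes a partial product that must be summed before the bias is added. A conceptual sketch of row-parallel linear, not the vLLM/DeepSpeed code paths used above:

    import torch
    import torch.distributed as dist

    def row_parallel_linear(x_local: torch.Tensor, w_local: torch.Tensor,
                            mp_group) -> torch.Tensor:
        # Each rank holds a slice of the weight's input dimension, so its
        # matmul yields a partial sum of the full output.
        partial = x_local @ w_local
        dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=mp_group)
        return partial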
@@ -731,7 +693,7 @@ class LowBitLinear(nn.Linear):
             result = MatMulLowBitCPU.apply(x, self.weight)
         else:
             from ipex_llm.utils.isa_checker import is_server, is_spr

             x0 = self.weight.data
             # convert if necessary, and compute a linear result
             if is_server() and (not is_spr()) and \
                     self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
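The CPU-side gate above dispatches large batches to plain torch linear on server-class, non-SPR machines. A standalone restatement of that predicate; the threshold value below is a placeholder, the real constant is TORCH_LINEAR_THRESHOLD:

    def should_use_torch_linear(server: bool, spr: bool,
                                qtype_is_sym_int4: bool, batch: int,
                                threshold: int = 512) -> bool:
        # Mirrors the condition in the diff; 512 is illustrative only.
        return server and not spr and qtype_is_sym_int4 and batch >= threshold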
@@ -259,19 +259,6 @@ def mlp_fusion_check(x, qtype, training):
     return True


-def use_xmx(x: torch.Tensor, qtype: int):
-    device = get_xpu_device_name(x.device)
-    return (
-        device in ["arc", "pvc"]
-        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4]
-        and (
-            (device == "pvc" and 1 < x.size(0) <= 16)
-            or
-            (device != "pvc" and 1 < x.size(0) <= 64)
-        )
-    )
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
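use_xmx is deleted from this utils module but not from the project: the refactored LowBitLinear.forward above still calls it without the old `from ipex_llm.transformers.models.utils import use_xmx`, so it presumably now lives alongside the linear code. Its batch-size gate, restated standalone for clarity (bounds copied from the removed function):

    def xmx_batch_ok(device: str, batch: int) -> bool:
        # pvc allows batches of 2..16 on the XMX path, other devices 2..64.
        limit = 16 if device == "pvc" else 64
        return 1 < batch <= limit

    assert xmx_batch_ok("pvc", 16) and not xmx_batch_ok("pvc", 17)
    assert xmx_batch_ok("arc", 64) and not xmx_batch_ok("arc", 1)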
@@ -20,9 +20,10 @@ import xe_batch
 import xe_addons


-# @torch.library.register_fake("ipex_llm::forward_new")
-# def _(x, weight, qtype, input_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::forward_new")
+def _(x, weight, qtype, output_size):
+    return torch.empty([x.size(0), output_size],
+                       dtype=x.dtype, device=x.device)


 # @torch.library.register_fake("ipex_llm::dequant")
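With the fake kernel registered, fake-tensor tracing (torch.compile, torch.export) can propagate shapes through the custom op instead of failing on it. A minimal usage sketch, assuming the real ipex_llm::forward_new op is registered by the native extension:

    import torch

    @torch.compile
    def lm_head(x_2d, packed_w, qtype, out_len):
        # Under compilation this op is traced via the fake kernel above, which
        # returns an empty [x_2d.size(0), out_len] tensor of x_2d's dtype.
        return torch.ops.ipex_llm.forward_new(x_2d, packed_w, qtype, out_len)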