refactor xpu linear forward (#12768)

This commit is contained in:
Yishuo Wang 2025-02-05 17:40:38 +08:00 committed by GitHub
parent 413d6c2b66
commit 0237ffb302
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 32 additions and 82 deletions

View file

@ -500,16 +500,16 @@ class MatMulLowBit(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, A, weight, input_seq_size):
def forward(ctx, A, weight, output_size):
ctx.is_empty = False
import xe_linear
if weight.qtype == NF4:
result = xe_linear.forward_new(A,
weight.data.view(torch.uint8),
weight.qtype,
input_seq_size)
output_size)
else:
result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
result = xe_linear.forward_new(A, weight.data, weight.qtype, output_size)
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, weight)
else:
@ -627,89 +627,50 @@ class LowBitLinear(nn.Linear):
if self.optimize_lm_head:
x = reshape_lm_head_input(x)
# [batch, input_num, in_len]
# input_num == token num for Transformer
x_shape = x.shape
# Output shape, e.g., [batch, input_num, out_len]
new_shape = x_shape[:-1] + (self.out_len,)
# [batch, seq_len, in_len] -> [batch, seq_len, out_len]
new_shape = x.shape[:-1] + (self.out_len,)
# Activation is empty tensor, e.g., [1, 0, 4096]
if 0 in x_shape:
if 0 in x.shape:
# return empty tensor with output shape, x.dtype and x.device
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
x_2d = x.contiguous().view(-1, x_shape[-1])
if self.act_order:
x_2d = x_2d[:, self.g_idx_map]
# x0 for weight
x0 = self.weight.data
x = x[..., self.g_idx_map]
if x0.device.type == "xpu":
# GPU logic
try:
import xe_linear
from ipex_llm.transformers.models.utils import use_xmx
except ModuleNotFoundError:
invalidInputError(False,
"Please `pip install bigdl_core_xe` first.")
x_2d = x.contiguous().view(-1, x.shape[-1])
if x_2d.is_contiguous() is False:
x_2d = x_2d.contiguous()
if len(x_shape) == 3:
input_seq_size = x_shape[1]
elif len(x_shape) < 3:
input_seq_size = 1
if is_training:
# training path
if x_2d.requires_grad:
result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
if self.weight.device.type == "xpu":
if is_training and x_2d.requires_grad:
result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
else:
if self.weight.qtype == NF4:
result = xe_linear.forward_new(x_2d,
self.weight.data.view(torch.uint8),
self.weight.qtype,
input_seq_size)
else:
result = xe_linear.forward_new(x_2d,
self.weight.data,
self.weight.qtype,
input_seq_size)
else:
# inference path
# current workaround to reduce first token latency of fp32 input
# sometimes fp16 cause nan and training instability
# disable the conversion when training
# TODO: may modify the input length condition for empty cache.
do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
if do_empty_cache:
torch.xpu.empty_cache()
if self.qtype == NF4:
w = self.weight.data.view(torch.uint8)
else:
w = self.weight.data
if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
import xe_batch
result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
elif (
self.conver_to_half
and x_2d.shape[0] > 1
and x_2d.dtype == torch.float32
and not use_xmx(x_2d, self.weight.qtype)
):
result = xe_batch.batch_forward(x_2d, w, self.qtype)
elif not is_training and self.conver_to_half \
and x_2d.shape[0] > 1 and x_2d.dtype == torch.float:
import xe_linear
x_2d = x_2d.half()
result = xe_linear.forward_new(x_2d, self.weight.data,
self.weight.qtype, input_seq_size)
result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
result = result.to(x.dtype)
else:
if self.weight.qtype == NF4:
result = xe_linear.forward_new(x_2d, self.weight.data.view(torch.uint8),
self.weight.qtype, input_seq_size)
else:
result = xe_linear.forward_new(x_2d, self.weight.data,
self.weight.qtype, input_seq_size)
import xe_linear
result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
if do_empty_cache:
torch.xpu.empty_cache()
result = result.view(new_shape)
if self.mp_group is not None:
if get_use_vllm():
result = self.mp_group.all_reduce(result)
@ -718,6 +679,7 @@ class LowBitLinear(nn.Linear):
dist.inference_all_reduce(result, group=self.mp_group)
else:
invalidInputError(False, "mp_group is not None, but no supported backend found")
if self.bias is not None:
result += self.bias
else:
@ -731,7 +693,7 @@ class LowBitLinear(nn.Linear):
result = MatMulLowBitCPU.apply(x, self.weight)
else:
from ipex_llm.utils.isa_checker import is_server, is_spr
x0 = self.weight.data
# convert if necessary, and compute a linear result
if is_server() and (not is_spr()) and \
self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:

View file

@ -259,19 +259,6 @@ def mlp_fusion_check(x, qtype, training):
return True
def use_xmx(x: torch.Tensor, qtype: int):
device = get_xpu_device_name(x.device)
return (
device in ["arc", "pvc"]
and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4]
and (
(device == "pvc" and 1 < x.size(0) <= 16)
or
(device != "pvc" and 1 < x.size(0) <= 64)
)
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:

View file

@ -20,9 +20,10 @@ import xe_batch
import xe_addons
# @torch.library.register_fake("ipex_llm::forward_new")
# def _(x, weight, qtype, input_size):
# return ???
@torch.library.register_fake("ipex_llm::forward_new")
def _(x, weight, qtype, output_size):
return torch.empty([x.size(0), output_size],
dtype=x.dtype, device=x.device)
# @torch.library.register_fake("ipex_llm::dequant")