refactor xpu linear forward (#12768)

Yishuo Wang 2025-02-05 17:40:38 +08:00 committed by GitHub
parent 413d6c2b66
commit 0237ffb302
3 changed files with 32 additions and 82 deletions


@@ -500,16 +500,16 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx, A, weight, input_seq_size):
+    def forward(ctx, A, weight, output_size):
         ctx.is_empty = False
         import xe_linear
         if weight.qtype == NF4:
             result = xe_linear.forward_new(A,
                                            weight.data.view(torch.uint8),
                                            weight.qtype,
-                                           input_seq_size)
+                                           output_size)
         else:
-            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, output_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
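
Note: the only surface change in this hunk is the renamed last argument; `forward_new` now receives the layer's output size rather than the input sequence length, and tensors are still stashed on `ctx` only when one of the first two inputs needs a gradient. Below is a plain-torch sketch of that bookkeeping, with an ordinary dense matmul standing in for `xe_linear.forward_new`; the class name and the dense `[out, in]` weight layout are illustrative assumptions, not part of this commit.

import torch

class MatMulSketch(torch.autograd.Function):
    # Dense stand-in for MatMulLowBit: same ctx bookkeeping, ordinary matmul
    # in place of the low-bit XPU kernel.

    @staticmethod
    def forward(ctx, A, weight, output_size):
        result = A @ weight.t()                 # kernel call in the real class
        # Save inputs for backward only when A or weight actually needs a grad,
        # mirroring the needs_input_grad check in the hunk above.
        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (A, weight)
        else:
            ctx.tensors = ()
        return result

    @staticmethod
    def backward(ctx, grad_output):
        A, weight = ctx.tensors if ctx.tensors else (None, None)
        grad_A = grad_output @ weight if weight is not None else None
        grad_w = grad_output.t() @ A if A is not None else None
        return grad_A, grad_w, None             # no gradient for output_size

A = torch.randn(4, 8, requires_grad=True)
W = torch.randn(16, 8, requires_grad=True)
MatMulSketch.apply(A, W, 16).sum().backward()   # A.grad: [4, 8], W.grad: [16, 8]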
@@ -627,89 +627,50 @@ class LowBitLinear(nn.Linear):
         if self.optimize_lm_head:
             x = reshape_lm_head_input(x)
 
-        # [batch, input_num, in_len]
-        # input_num == token num for Transformer
-        x_shape = x.shape
-        # Output shape, e.g., [batch, input_num, out_len]
-        new_shape = x_shape[:-1] + (self.out_len,)
+        # [batch, seq_len, in_len] -> [batch, seq_len, out_len]
+        new_shape = x.shape[:-1] + (self.out_len,)
+
         # Activation is empty tensor, e.g., [1, 0, 4096]
-        if 0 in x_shape:
+        if 0 in x.shape:
             # return empty tensor with output shape, x.dtype and x.device
             return torch.empty(new_shape, dtype=x.dtype, device=x.device)
 
-        x_2d = x.contiguous().view(-1, x_shape[-1])
-
         if self.act_order:
-            x_2d = x_2d[:, self.g_idx_map]
+            x = x[..., self.g_idx_map]
 
-        # x0 for weight
-        x0 = self.weight.data
+        x_2d = x.contiguous().view(-1, x.shape[-1])
 
-        if x0.device.type == "xpu":
-            # GPU logic
-            try:
-                import xe_linear
-                from ipex_llm.transformers.models.utils import use_xmx
-            except ModuleNotFoundError:
-                invalidInputError(False,
-                                  "Please `pip install bigdl_core_xe` first.")
-
-            if x_2d.is_contiguous() is False:
-                x_2d = x_2d.contiguous()
-
-            if len(x_shape) == 3:
-                input_seq_size = x_shape[1]
-            elif len(x_shape) < 3:
-                input_seq_size = 1
-
-            if is_training:
-                # training path
-                if x_2d.requires_grad:
-                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
-                else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype,
-                                                       input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data,
-                                                       self.weight.qtype,
-                                                       input_seq_size)
+        if self.weight.device.type == "xpu":
+            if is_training and x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
-                # inference path
-                # current workaround to reduce first token latency of fp32 input
-                # sometimes fp16 cause nan and training instability
-                # disable the conversion when training
-                # TODO: may modify the input length condition for empty cache.
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
                 if do_empty_cache:
                     torch.xpu.empty_cache()
 
+                if self.qtype == NF4:
+                    w = self.weight.data.view(torch.uint8)
+                else:
+                    w = self.weight.data
+
                 if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
                     import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
-                elif (
-                    self.conver_to_half
-                    and x_2d.shape[0] > 1
-                    and x_2d.dtype == torch.float32
-                    and not use_xmx(x_2d, self.weight.qtype)
-                ):
+                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
+                elif not is_training and self.conver_to_half \
+                        and x_2d.shape[0] > 1 and x_2d.dtype == torch.float:
+                    import xe_linear
                     x_2d = x_2d.half()
-                    result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                   self.weight.qtype, input_seq_size)
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
                     result = result.to(x.dtype)
                 else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d, self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype, input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                       self.weight.qtype, input_seq_size)
+                    import xe_linear
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
 
                 if do_empty_cache:
                     torch.xpu.empty_cache()
 
             result = result.view(new_shape)
             if self.mp_group is not None:
                 if get_use_vllm():
                     result = self.mp_group.all_reduce(result)
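
The refactored forward keeps the same shape bookkeeping: flatten `[batch, seq_len, in_len]` activations to 2D, run the kernel, then view the result back to `new_shape`, with an early return for empty activations. A minimal pure-torch sketch of that flow, with `torch.nn.functional.linear` standing in for the low-bit XPU kernel (the function name and shapes here are illustrative, not from the commit):

import torch
import torch.nn.functional as F

def flatten_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # [batch, seq_len, in_len] -> [batch, seq_len, out_len], as in the new forward
    out_len = weight.shape[0]
    new_shape = x.shape[:-1] + (out_len,)
    if 0 in x.shape:
        # empty activation: return an empty tensor with the output shape
        return torch.empty(new_shape, dtype=x.dtype, device=x.device)
    x_2d = x.contiguous().view(-1, x.shape[-1])   # flatten tokens to 2D
    result = F.linear(x_2d, weight)               # stand-in for the low-bit kernel
    return result.view(new_shape)                 # restore the leading dims

x = torch.randn(2, 5, 16)          # [batch, seq_len, in_len]
w = torch.randn(32, 16)            # [out_len, in_len]
print(flatten_linear(x, w).shape)  # torch.Size([2, 5, 32])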
@@ -718,6 +679,7 @@ class LowBitLinear(nn.Linear):
                     dist.inference_all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
+
             if self.bias is not None:
                 result += self.bias
         else:
@@ -731,7 +693,7 @@ class LowBitLinear(nn.Linear):
                 result = MatMulLowBitCPU.apply(x, self.weight)
             else:
                 from ipex_llm.utils.isa_checker import is_server, is_spr
-                x0 = self.weight.data
+
                 # convert if necessary, and compute a linear result
                 if is_server() and (not is_spr()) and \
                         self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:


@@ -259,19 +259,6 @@ def mlp_fusion_check(x, qtype, training):
     return True
 
 
-def use_xmx(x: torch.Tensor, qtype: int):
-    device = get_xpu_device_name(x.device)
-    return (
-        device in ["arc", "pvc"]
-        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4]
-        and (
-            (device == "pvc" and 1 < x.size(0) <= 16)
-            or
-            (device != "pvc" and 1 < x.size(0) <= 64)
-        )
-    )
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:


@@ -20,9 +20,10 @@ import xe_batch
 import xe_addons
 
 
-# @torch.library.register_fake("ipex_llm::forward_new")
-# def _(x, weight, qtype, input_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::forward_new")
+def _(x, weight, qtype, output_size):
+    return torch.empty([x.size(0), output_size],
+                       dtype=x.dtype, device=x.device)
 
 
 # @torch.library.register_fake("ipex_llm::dequant")
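
The now-enabled registration gives `ipex_llm::forward_new` a fake (meta) implementation, so FakeTensor tracing and `torch.compile` can infer its `[x.size(0), output_size]` output without dispatching to the XPU kernel. A self-contained sketch of the same pattern on a toy op (the `demo::scale` namespace is made up for illustration; `torch.library.custom_op` and `register_fake` require a recent PyTorch, roughly 2.4+):

import torch

# Toy custom op registered the same way ipex_llm exposes its XPU ops.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# The fake (meta) implementation only describes the output's metadata, which is
# what lets torch.compile / FakeTensor tracing reason about the op without
# running its kernel -- the role the ipex_llm::forward_new fake impl plays above.
@torch.library.register_fake("demo::scale")
def _(x, factor):
    return torch.empty_like(x)

print(scale(torch.ones(3), 2.0))   # eager call still runs the real implementation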