refactor xpu linear forward (#12768)
This commit is contained in:
parent 413d6c2b66
commit 0237ffb302
3 changed files with 32 additions and 82 deletions
@@ -500,16 +500,16 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
-    def forward(ctx, A, weight, input_seq_size):
+    def forward(ctx, A, weight, output_size):
         ctx.is_empty = False
         import xe_linear
         if weight.qtype == NF4:
             result = xe_linear.forward_new(A,
                                            weight.data.view(torch.uint8),
                                            weight.qtype,
-                                           input_seq_size)
+                                           output_size)
         else:
-            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, output_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
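For context on the signature change above: forward_new consumes a 2-D activation, and the new third argument is the output feature count rather than the input sequence length, so the result shape can be stated without inspecting the packed low-bit weight. A minimal, illustrative sketch (the helper name expected_output_shape is hypothetical and not part of this commit):

import torch

def expected_output_shape(A: torch.Tensor, output_size: int) -> torch.Size:
    # Hypothetical helper for illustration only: forward_new takes a
    # 2-D activation [tokens, in_len] and returns [tokens, output_size],
    # which is what the fake kernel registered later in this commit
    # constructs with torch.empty.
    return torch.Size([A.size(0), output_size])

A = torch.randn(4, 4096)
assert expected_output_shape(A, 11008) == torch.Size([4, 11008])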
@@ -627,89 +627,50 @@ class LowBitLinear(nn.Linear):
         if self.optimize_lm_head:
             x = reshape_lm_head_input(x)
 
-        # [batch, input_num, in_len]
-        # input_num == token num for Transformer
-        x_shape = x.shape
-        # Output shape, e.g., [batch, input_num, out_len]
-        new_shape = x_shape[:-1] + (self.out_len,)
+        # [batch, seq_len, in_len] -> [batch, seq_len, out_len]
+        new_shape = x.shape[:-1] + (self.out_len,)
         # Activation is empty tensor, e.g., [1, 0, 4096]
-        if 0 in x_shape:
+        if 0 in x.shape:
             # return empty tensor with output shape, x.dtype and x.device
             return torch.empty(new_shape, dtype=x.dtype, device=x.device)
 
-        x_2d = x.contiguous().view(-1, x_shape[-1])
-
         if self.act_order:
-            x_2d = x_2d[:, self.g_idx_map]
-        # x0 for weight
-        x0 = self.weight.data
+            x = x[..., self.g_idx_map]
 
-        if x0.device.type == "xpu":
-            # GPU logic
-            try:
-                import xe_linear
-                from ipex_llm.transformers.models.utils import use_xmx
-            except ModuleNotFoundError:
-                invalidInputError(False,
-                                  "Please `pip install bigdl_core_xe` first.")
+        x_2d = x.contiguous().view(-1, x.shape[-1])
 
-            if x_2d.is_contiguous() is False:
-                x_2d = x_2d.contiguous()
-
-            if len(x_shape) == 3:
-                input_seq_size = x_shape[1]
-            elif len(x_shape) < 3:
-                input_seq_size = 1
-
-            if is_training:
-                # training path
-                if x_2d.requires_grad:
-                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
-                else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype,
-                                                       input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d,
-                                                       self.weight.data,
-                                                       self.weight.qtype,
-                                                       input_seq_size)
+        if self.weight.device.type == "xpu":
+            if is_training and x_2d.requires_grad:
+                result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
-                # inference path
-                # current workaround to reduce first token latency of fp32 input
-                # sometimes fp16 cause nan and training instability
-                # disable the conversion when training
-                # TODO: may modify the input length condition for empty cache.
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
                 if do_empty_cache:
                     torch.xpu.empty_cache()
 
+                if self.qtype == NF4:
+                    w = self.weight.data.view(torch.uint8)
+                else:
+                    w = self.weight.data
+
                 if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
                     import xe_batch
-                    result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
-                elif (
-                    self.conver_to_half
-                    and x_2d.shape[0] > 1
-                    and x_2d.dtype == torch.float32
-                    and not use_xmx(x_2d, self.weight.qtype)
-                ):
+                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
+                elif not is_training and self.conver_to_half \
+                        and x_2d.shape[0] > 1 and x_2d.dtype == torch.float:
+                    import xe_linear
                     x_2d = x_2d.half()
-                    result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                   self.weight.qtype, input_seq_size)
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
                     result = result.to(x.dtype)
                 else:
-                    if self.weight.qtype == NF4:
-                        result = xe_linear.forward_new(x_2d, self.weight.data.view(torch.uint8),
-                                                       self.weight.qtype, input_seq_size)
-                    else:
-                        result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                       self.weight.qtype, input_seq_size)
+                    import xe_linear
+                    result = xe_linear.forward_new(x_2d, w, self.qtype, self.out_len)
 
                 if do_empty_cache:
                     torch.xpu.empty_cache()
 
             result = result.view(new_shape)
 
             if self.mp_group is not None:
                 if get_use_vllm():
                     result = self.mp_group.all_reduce(result)
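One subtle point in the hunk above: the act_order permutation now runs on the full [batch, seq_len, in_len] tensor via x[..., self.g_idx_map] instead of on the flattened 2-D view. A small self-contained sketch showing the two forms are equivalent (toy shapes, illustration only):

import torch

x = torch.randn(2, 3, 8)          # [batch, seq_len, in_len]
g_idx_map = torch.randperm(8)     # column permutation used by act_order

permuted_3d = x[..., g_idx_map]                          # new form
permuted_2d = x.contiguous().view(-1, 8)[:, g_idx_map]   # old form
assert torch.equal(permuted_3d.reshape(-1, 8), permuted_2d)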
@@ -718,6 +679,7 @@ class LowBitLinear(nn.Linear):
                     dist.inference_all_reduce(result, group=self.mp_group)
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
 
         if self.bias is not None:
             result += self.bias
         else:
@@ -731,7 +693,7 @@ class LowBitLinear(nn.Linear):
                 result = MatMulLowBitCPU.apply(x, self.weight)
             else:
                 from ipex_llm.utils.isa_checker import is_server, is_spr
+                x0 = self.weight.data
                 # convert if necessary, and compute a linear result
                 if is_server() and (not is_spr()) and \
                         self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
@@ -259,19 +259,6 @@ def mlp_fusion_check(x, qtype, training):
     return True
 
 
-def use_xmx(x: torch.Tensor, qtype: int):
-    device = get_xpu_device_name(x.device)
-    return (
-        device in ["arc", "pvc"]
-        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5, WOQ_INT4]
-        and (
-            (device == "pvc" and 1 < x.size(0) <= 16)
-            or
-            (device != "pvc" and 1 < x.size(0) <= 64)
-        )
-    )
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
@@ -20,9 +20,10 @@ import xe_batch
 import xe_addons
 
 
-# @torch.library.register_fake("ipex_llm::forward_new")
-# def _(x, weight, qtype, input_size):
-#     return ???
+@torch.library.register_fake("ipex_llm::forward_new")
+def _(x, weight, qtype, output_size):
+    return torch.empty([x.size(0), output_size],
+                       dtype=x.dtype, device=x.device)
 
 
 # @torch.library.register_fake("ipex_llm::dequant")
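The hunk above turns the commented-out stub into a real fake (meta) kernel, so shape inference for ipex_llm::forward_new no longer needs to run the XPU kernel. A self-contained sketch of the same pattern with a toy custom op (PyTorch 2.4+ API; demo::lowbit_linear is hypothetical and used only for illustration, the real op here is provided by the xe_linear extension):

import torch

@torch.library.custom_op("demo::lowbit_linear", mutates_args=())
def lowbit_linear(x: torch.Tensor, weight: torch.Tensor,
                  qtype: int, output_size: int) -> torch.Tensor:
    # Stand-in for the real low-bit kernel: just produce the right shape.
    return x.new_zeros(x.size(0), output_size)

@lowbit_linear.register_fake
def _(x, weight, qtype, output_size):
    # Mirrors the fake registered in this commit: shape/dtype/device only.
    return torch.empty([x.size(0), output_size], dtype=x.dtype, device=x.device)

x = torch.randn(4, 16)
w = torch.zeros(8, dtype=torch.uint8)
out = torch.ops.demo.lowbit_linear(x, w, 0, 32)
assert out.shape == (4, 32)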