update api usage of xe_batch & fp16 (#11164)

* update api usage

* update setup.py
This commit is contained in:
Ruonan Wang 2024-05-29 07:15:14 +00:00 committed by GitHub
parent e29e2f1c78
commit 9bfbf78bf4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 6 additions and 14 deletions

View file

@@ -299,8 +299,7 @@ def setup_package():
"intel_extension_for_pytorch==2.1.10+xpu",
"bigdl-core-xe-21==" + CORE_XE_VERSION,
"bigdl-core-xe-batch-21==" + CORE_XE_VERSION,
"bigdl-core-xe-addons-21==" + CORE_XE_VERSION,
"bigdl-core-xe-esimd-21==" + CORE_XE_VERSION]
"bigdl-core-xe-addons-21==" + CORE_XE_VERSION]
xpu_21_requires += oneapi_2024_0_requires
# default to ipex 2.1 for linux and windows
xpu_requires = copy.deepcopy(xpu_21_requires)

View file

@@ -720,8 +720,7 @@ class LowBitLinear(nn.Linear):
if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
import xe_batch
result = xe_batch.batch_forward(x_2d, self.weight.data,
self.weight.qtype,
input_seq_size)
self.weight.qtype)
else:
result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
input_seq_size)
@@ -730,8 +729,7 @@ class LowBitLinear(nn.Linear):
if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
import xe_batch
result = xe_batch.batch_forward(x_2d, self.weight.data,
self.weight.qtype,
input_seq_size)
self.weight.qtype)
else:
result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
input_seq_size)
@@ -843,13 +841,6 @@ class FP16Linear(nn.Linear):
if x_2d.is_contiguous() is False:
x_2d = x_2d.contiguous()
try:
import intel_extension_for_pytorch
import linear_fp16_esimd
except ModuleNotFoundError:
invalidInputError(False,
"Please `pip install bigdl_core_xe_esimd` first.")
if x_2d.shape[0] > 8:
# first token or batch size > 8, re-convert weight
if self.weight_type == 3:
@@ -861,7 +852,9 @@ class FP16Linear(nn.Linear):
result = F.linear(x_2d, self.weight)
else:
# batch size <= 8, use esimd optimization
result = linear_fp16_esimd.forward(x_2d, self.weight.data)
import xe_batch
result = xe_batch.batch_forward(x_2d, self.weight.data,
self.qtype)
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)