update api usage of xe_batch & fp16 (#11164)
* update api usage
* update setup.py
parent e29e2f1c78
commit 9bfbf78bf4
2 changed files with 6 additions and 14 deletions
@@ -299,8 +299,7 @@ def setup_package():
         "intel_extension_for_pytorch==2.1.10+xpu",
         "bigdl-core-xe-21==" + CORE_XE_VERSION,
-        "bigdl-core-xe-batch-21==" + CORE_XE_VERSION,
-        "bigdl-core-xe-addons-21==" + CORE_XE_VERSION,
-        "bigdl-core-xe-esimd-21==" + CORE_XE_VERSION]
+        "bigdl-core-xe-addons-21==" + CORE_XE_VERSION]
     xpu_21_requires += oneapi_2024_0_requires
     # default to ipex 2.1 for linux and windows
     xpu_requires = copy.deepcopy(xpu_21_requires)
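Net effect in setup.py: the separate bigdl-core-xe-batch-21 and bigdl-core-xe-esimd-21 wheels are dropped, leaving only the core and addons packages for the IPEX 2.1 XPU target. A minimal sketch of the resulting requirement list; the CORE_XE_VERSION value and the xpu_21_requires assignment around it are assumptions based on the context lines above, not copied verbatim from the file:

    # Sketch only; CORE_XE_VERSION below is a hypothetical placeholder value.
    CORE_XE_VERSION = "2.1.0"
    xpu_21_requires = [
        "intel_extension_for_pytorch==2.1.10+xpu",
        "bigdl-core-xe-21==" + CORE_XE_VERSION,
        "bigdl-core-xe-addons-21==" + CORE_XE_VERSION]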
@@ -720,8 +720,7 @@ class LowBitLinear(nn.Linear):
         if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
             import xe_batch
             result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                            self.weight.qtype,
-                                            input_seq_size)
+                                            self.weight.qtype)
         else:
             result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                            input_seq_size)
@@ -730,8 +729,7 @@ class LowBitLinear(nn.Linear):
         if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
             import xe_batch
             result = xe_batch.batch_forward(x_2d, self.weight.data,
-                                            self.weight.qtype,
-                                            input_seq_size)
+                                            self.weight.qtype)
         else:
             result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                            input_seq_size)
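Both LowBitLinear hunks make the same API change: xe_batch.batch_forward no longer takes input_seq_size, only the input, the weight data, and the weight qtype. A minimal sketch of the updated dispatch, using only names that appear in the diff context; this is a fragment of the forward path, not a standalone function, and which remaining bigdl-core-xe wheel now ships the xe_batch module is an assumption:

    if use_batch_forward(x_2d, self.weight.qtype, self.out_len):
        import xe_batch  # presumably provided by one of the remaining bigdl-core-xe wheels
        # new signature: qtype is now the last argument, input_seq_size is gone
        result = xe_batch.batch_forward(x_2d, self.weight.data, self.weight.qtype)
    else:
        # the non-batch kernel keeps its previous signature, still taking input_seq_size
        result = xe_linear.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                       input_seq_size)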
@@ -843,13 +841,6 @@ class FP16Linear(nn.Linear):
         if x_2d.is_contiguous() is False:
             x_2d = x_2d.contiguous()
 
-        try:
-            import intel_extension_for_pytorch
-            import linear_fp16_esimd
-        except ModuleNotFoundError:
-            invalidInputError(False,
-                              "Please `pip install bigdl_core_xe_esimd` first.")
-
         if x_2d.shape[0] > 8:
             # first token or batch size > 8, re-convert weight
             if self.weight_type == 3:
@@ -861,7 +852,9 @@ class FP16Linear(nn.Linear):
             result = F.linear(x_2d, self.weight)
         else:
-            # batch size <= 8, use esimd optimization
-            result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+            import xe_batch
+            result = xe_batch.batch_forward(x_2d, self.weight.data,
+                                            self.qtype)
 
         new_shape = x_shape[:-1] + (self.out_len,)
         result = result.view(new_shape)
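In FP16Linear, the batch-size <= 8 path no longer calls linear_fp16_esimd.forward, which is why the try/except import of linear_fp16_esimd and the bigdl-core-xe-esimd-21 requirement could be removed. The small-batch case now routes through the same xe_batch.batch_forward entry point, passing self.qtype. A minimal sketch of the resulting branch, as a fragment under the assumptions visible in the diff context (the weight re-conversion details of the large-batch path are omitted):

    if x_2d.shape[0] > 8:
        # first token or batch size > 8: plain fp16 matmul on the (re-converted) weight
        result = F.linear(x_2d, self.weight)
    else:
        # batch size <= 8: fused batch kernel replaces the old esimd path
        import xe_batch
        result = xe_batch.batch_forward(x_2d, self.weight.data, self.qtype)
    new_shape = x_shape[:-1] + (self.out_len,)
    result = result.view(new_shape)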