LLM: update fp16 Linear on ARC/FLEX (#10023)
This commit is contained in:
parent
a5c9dfdf91
commit
ccf8f613fb
1 changed files with 15 additions and 12 deletions
|
|
@ -575,7 +575,8 @@ class FP16Linear(nn.Linear):
|
|||
self.weight_type = 2
|
||||
return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
|
||||
else:
|
||||
if self.weight_type != 3:
|
||||
if self.in_len == 4096 and self.weight_type != 3 or \
|
||||
self.in_len == 11008 and self.weight_type != 1:
|
||||
# convert weight first to use esimd fp16 kernel
|
||||
self.convert_weight_for_esimd_kernel()
|
||||
# esimd fp16 kernel for inference
|
||||
|
|
@ -591,14 +592,17 @@ class FP16Linear(nn.Linear):
|
|||
invalidInputError(False,
|
||||
"Please `pip install bigdl_core_xe_esimd` first.")
|
||||
|
||||
if x_2d.shape[0] > 1:
|
||||
# first token or batch size > 1, re-convert weight
|
||||
if x_2d.shape[0] > 8:
|
||||
# first token or batch size > 8, re-convert weight
|
||||
if self.weight_type == 3:
|
||||
original_weight = self.weight.data.transpose(1, 2)
|
||||
original_weight = original_weight.reshape(self.out_len, self.in_len)
|
||||
result = F.linear(x_2d, original_weight.contiguous())
|
||||
del original_weight
|
||||
else:
|
||||
# rest token, use esimd optimization
|
||||
result = F.linear(x_2d, self.weight)
|
||||
else:
|
||||
# batch size <= 8, use esimd optimization
|
||||
result = linear_fp16_esimd.forward(x_2d, self.weight.data)
|
||||
|
||||
new_shape = x_shape[:-1] + (self.out_len,)
|
||||
|
|
@ -632,9 +636,8 @@ class FP16Linear(nn.Linear):
|
|||
trans_weight = self.weight.data.transpose(0, 1)
|
||||
else:
|
||||
trans_weight = self.weight.data
|
||||
trans_weight = trans_weight.data.reshape(m//8, 8, n)
|
||||
trans_weight = trans_weight.transpose(1, 2).contiguous()
|
||||
self.weight.data = trans_weight
|
||||
self.weight.data = trans_weight.contiguous()
|
||||
self.weight_type = 1
|
||||
elif self.in_len == 4096:
|
||||
if self.weight_type == 2:
|
||||
trans_weight = self.weight.data.transpose(0, 1)
|
||||
|
|
|
|||
Loading…
Reference in a new issue