LLM: update fp16 Linear on ARC/FLEX (#10023)

This commit is contained in:
parent a5c9dfdf91
commit ccf8f613fb

1 changed file with 15 additions and 12 deletions
@@ -575,7 +575,8 @@ class FP16Linear(nn.Linear):
             self.weight_type = 2
             return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
         else:
-            if self.weight_type != 3:
+            if self.in_len == 4096 and self.weight_type != 3 or \
+                    self.in_len == 11008 and self.weight_type != 1:
                 # convert weight first to use esimd fp16 kernel
                 self.convert_weight_for_esimd_kernel()
             # esimd fp16 kernel for inference
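Note that the new guard relies on Python operator precedence: `and` binds tighter than `or`, so the chained condition parses as (in_len == 4096 and weight_type != 3) or (in_len == 11008 and weight_type != 1). Judging by the last hunk below, each in_len now has its own target layout: type 3 for 4096 and type 1 for 11008. A minimal sketch of the predicate, with a hypothetical standalone function in place of the method:

# Hypothetical standalone predicate mirroring the new guard; the
# parenthesization below is what Python's precedence rules give the
# unparenthesized `and`/`or` chain in the diff.
def needs_esimd_conversion(in_len: int, weight_type: int) -> bool:
    return ((in_len == 4096 and weight_type != 3)
            or (in_len == 11008 and weight_type != 1))

assert needs_esimd_conversion(4096, 2)          # 4096 targets layout type 3
assert not needs_esimd_conversion(4096, 3)      # already converted
assert not needs_esimd_conversion(11008, 1)     # 11008 targets layout type 1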
@@ -591,14 +592,17 @@ class FP16Linear(nn.Linear):
                 invalidInputError(False,
                                   "Please `pip install bigdl_core_xe_esimd` first.")
 
-            if x_2d.shape[0] > 1:
-                # first token or batch size > 1, re-convert weight
-                original_weight = self.weight.data.transpose(1, 2)
-                original_weight = original_weight.reshape(self.out_len, self.in_len)
-                result = F.linear(x_2d, original_weight.contiguous())
-                del original_weight
+            if x_2d.shape[0] > 8:
+                # first token or batch size > 8, re-convert weight
+                if self.weight_type == 3:
+                    original_weight = self.weight.data.transpose(1, 2)
+                    original_weight = original_weight.reshape(self.out_len, self.in_len)
+                    result = F.linear(x_2d, original_weight.contiguous())
+                    del original_weight
+                else:
+                    result = F.linear(x_2d, self.weight)
             else:
-                # rest token, use esimd optimization
+                # batch size <= 8, use esimd optimization
                 result = linear_fp16_esimd.forward(x_2d, self.weight.data)
 
             new_shape = x_shape[:-1] + (self.out_len,)
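The fallback threshold moves from batch size 1 to 8, and the fallback now handles both weight layouts: a type-3 (blocked) weight is flattened back to (out_len, in_len) before F.linear, while any other layout goes to F.linear directly. A self-contained sketch of the new dispatch, with a hypothetical function name and F.linear standing in for linear_fp16_esimd.forward, which is not importable here:

import torch
import torch.nn.functional as F

def fp16_linear_dispatch(x_2d, weight, weight_type, in_len, out_len):
    # Sketch of the updated branch; F.linear also stubs the esimd kernel.
    if x_2d.shape[0] > 8:
        # batch size > 8: fall back to stock F.linear
        if weight_type == 3:
            # blocked layout: flatten back to (out_len, in_len) first
            original_weight = weight.transpose(1, 2).reshape(out_len, in_len)
            return F.linear(x_2d, original_weight.contiguous())
        return F.linear(x_2d, weight)
    # batch size <= 8: esimd kernel in the real code (stubbed here)
    return F.linear(x_2d, weight)

x = torch.randn(4, 4096)        # fp32 toy input; the real path runs fp16 on XPU
w = torch.randn(11008, 4096)    # plain (out_len, in_len) weight
y = fp16_linear_dispatch(x, w, weight_type=1, in_len=4096, out_len=11008)
assert y.shape == (4, 11008)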
@@ -632,9 +636,8 @@ class FP16Linear(nn.Linear):
                 trans_weight = self.weight.data.transpose(0, 1)
             else:
                 trans_weight = self.weight.data
-            trans_weight = trans_weight.data.reshape(m//8, 8, n)
-            trans_weight = trans_weight.transpose(1, 2).contiguous()
-            self.weight.data = trans_weight
+            self.weight.data = trans_weight.contiguous()
+            self.weight_type = 1
         elif self.in_len == 4096:
             if self.weight_type == 2:
                 trans_weight = self.weight.data.transpose(0, 1)
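In this branch (the one preceding elif self.in_len == 4096:, presumably in_len == 11008), the conversion no longer builds the (m//8, 8, n) -> transpose(1, 2) blocked layout; it stores the contiguous transposed weight and tags it weight_type = 1, matching the new guard in forward. A toy round-trip, under assumed dimensions, showing that the old blocked layout is exactly what forward's weight_type == 3 branch inverts:

import torch

m, n = 16, 4            # toy stand-ins for the real dimensions; m % 8 == 0
w = torch.arange(m * n, dtype=torch.float32).reshape(m, n)

# old conversion: (m, n) -> (m//8, 8, n) -> (m//8, n, 8) blocked layout
blocked = w.reshape(m // 8, 8, n).transpose(1, 2).contiguous()

# forward()'s weight_type == 3 branch undoes it losslessly
restored = blocked.transpose(1, 2).reshape(m, n)
assert torch.equal(restored, w)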