LLM: update fp16 Linear on ARC/FLEX (#10023)

Ruonan Wang 2024-01-29 18:25:26 +08:00 committed by GitHub
parent a5c9dfdf91
commit ccf8f613fb


@@ -575,7 +575,8 @@ class FP16Linear(nn.Linear):
                 self.weight_type = 2
             return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
         else:
-            if self.weight_type != 3:
+            if self.in_len == 4096 and self.weight_type != 3 or \
+                    self.in_len == 11008 and self.weight_type != 1:
                 # convert weight first to use esimd fp16 kernel
                 self.convert_weight_for_esimd_kernel()
             # esimd fp16 kernel for inference
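Note: the new guard re-converts the weight only when its current layout does not match what the esimd kernel expects for that hidden size. A minimal sketch of the condition, assuming (from this diff) that weight_type 1 is the plain (out_len, in_len) layout and weight_type 3 is the esimd-blocked layout:

def needs_esimd_conversion(in_len: int, weight_type: int) -> bool:
    # 4096-wide layers want the blocked layout (type 3);
    # 11008-wide layers now stay in the plain layout (type 1)
    return (in_len == 4096 and weight_type != 3) or \
           (in_len == 11008 and weight_type != 1)

assert needs_esimd_conversion(4096, 2)
assert not needs_esimd_conversion(11008, 1)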
@@ -591,14 +592,17 @@ class FP16Linear(nn.Linear):
                 invalidInputError(False,
                                   "Please `pip install bigdl_core_xe_esimd` first.")
 
-            if x_2d.shape[0] > 1:
-                # first token or batch size > 1, re-convert weight
-                original_weight = self.weight.data.transpose(1, 2)
-                original_weight = original_weight.reshape(self.out_len, self.in_len)
-                result = F.linear(x_2d, original_weight.contiguous())
-                del original_weight
+            if x_2d.shape[0] > 8:
+                # first token or batch size > 8, re-convert weight
+                if self.weight_type == 3:
+                    original_weight = self.weight.data.transpose(1, 2)
+                    original_weight = original_weight.reshape(self.out_len, self.in_len)
+                    result = F.linear(x_2d, original_weight.contiguous())
+                    del original_weight
+                else:
+                    result = F.linear(x_2d, self.weight)
             else:
-                # rest token, use esimd optimization
+                # batch size <= 8, use esimd optimization
                 result = linear_fp16_esimd.forward(x_2d, self.weight.data)
 
             new_shape = x_shape[:-1] + (self.out_len,)
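Note: a minimal sketch of the routing this hunk introduces; the esimd kernel is passed in as a callable so the example stays self-contained, and the function name is illustrative rather than taken from the repo:

import torch.nn.functional as F

def fp16_linear_dispatch(x_2d, weight, weight_type, out_len, in_len, esimd_forward):
    if x_2d.shape[0] > 8:
        # first token or batch size > 8: use a plain fp16 GEMM
        if weight_type == 3:
            # weight is in the esimd-blocked layout; restore (out_len, in_len) first
            w = weight.data.transpose(1, 2).reshape(out_len, in_len).contiguous()
        else:
            w = weight
        return F.linear(x_2d, w)
    # batch size <= 8: the esimd fp16 kernel consumes the stored layout directly
    return esimd_forward(x_2d, weight.data)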
@@ -632,9 +636,8 @@ class FP16Linear(nn.Linear):
                 trans_weight = self.weight.data.transpose(0, 1)
             else:
                 trans_weight = self.weight.data
-            trans_weight = trans_weight.data.reshape(m//8, 8, n)
-            trans_weight = trans_weight.transpose(1, 2).contiguous()
-            self.weight.data = trans_weight
+            self.weight.data = trans_weight.contiguous()
+            self.weight_type = 1
         elif self.in_len == 4096:
             if self.weight_type == 2:
                 trans_weight = self.weight.data.transpose(0, 1)
@@ -643,7 +646,7 @@ class FP16Linear(nn.Linear):
             trans_weight = trans_weight.data.reshape(m//16, 16, n)
             trans_weight = trans_weight.transpose(1, 2).contiguous()
             self.weight.data = trans_weight
-        self.weight_type = 3
+            self.weight_type = 3
 
 
 class BF16Linear(nn.Linear):
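Note: after this change the two paths in convert_weight_for_esimd_kernel diverge: in_len == 11008 weights are only made contiguous in the plain (out_len, in_len) layout and tagged weight_type 1, while in_len == 4096 weights are re-blocked for the esimd kernel and tagged weight_type 3. A standalone sketch, with names and shape handling inferred from the diff:

import torch

def convert_for_esimd(weight: torch.Tensor, weight_type: int, out_len: int, in_len: int):
    # weight_type == 2 appears to mean the weight is stored transposed, (in_len, out_len)
    w = weight.transpose(0, 1) if weight_type == 2 else weight
    if in_len == 11008:
        # plain contiguous (out_len, in_len) layout is enough; type becomes 1
        return w.contiguous(), 1
    if in_len == 4096:
        # re-block into (out_len//16, in_len, 16) for the esimd kernel; type becomes 3
        w = w.reshape(out_len // 16, 16, in_len).transpose(1, 2).contiguous()
        return w, 3
    return w, weight_type  # other hidden sizes are left untouched in this sketch

w = torch.randn(4096, 4096, dtype=torch.float16)
blocked, wtype = convert_for_esimd(w, weight_type=1, out_len=4096, in_len=4096)
print(blocked.shape, wtype)  # torch.Size([256, 4096, 16]) 3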