diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index c51780ff..ffeb6cba 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -575,7 +575,8 @@ class FP16Linear(nn.Linear):
                     self.weight_type = 2
                 return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
         else:
-            if self.weight_type != 3:
+            if self.in_len == 4096 and self.weight_type != 3 or \
+                    self.in_len == 11008 and self.weight_type != 1:
                 # convert weight first to use esimd fp16 kernel
                 self.convert_weight_for_esimd_kernel()
             # esimd fp16 kernel for inference
@@ -591,14 +592,17 @@ class FP16Linear(nn.Linear):
                 invalidInputError(False,
                                   "Please `pip install bigdl_core_xe_esimd` first.")
 
-            if x_2d.shape[0] > 1:
-                # first token or batch size > 1, re-convert weight
-                original_weight = self.weight.data.transpose(1, 2)
-                original_weight = original_weight.reshape(self.out_len, self.in_len)
-                result = F.linear(x_2d, original_weight.contiguous())
-                del original_weight
+            if x_2d.shape[0] > 8:
+                # first token or batch size > 8, re-convert weight
+                if self.weight_type == 3:
+                    original_weight = self.weight.data.transpose(1, 2)
+                    original_weight = original_weight.reshape(self.out_len, self.in_len)
+                    result = F.linear(x_2d, original_weight.contiguous())
+                    del original_weight
+                else:
+                    result = F.linear(x_2d, self.weight)
             else:
-                # rest token, use esimd optimization
+                # batch size <= 8, use esimd optimization
                 result = linear_fp16_esimd.forward(x_2d, self.weight.data)
 
             new_shape = x_shape[:-1] + (self.out_len,)
@@ -632,9 +636,8 @@ class FP16Linear(nn.Linear):
                 trans_weight = self.weight.data.transpose(0, 1)
             else:
                 trans_weight = self.weight.data
-            trans_weight = trans_weight.data.reshape(m//8, 8, n)
-            trans_weight = trans_weight.transpose(1, 2).contiguous()
-            self.weight.data = trans_weight
+            self.weight.data = trans_weight.contiguous()
+            self.weight_type = 1
         elif self.in_len == 4096:
             if self.weight_type == 2:
                 trans_weight = self.weight.data.transpose(0, 1)
@@ -643,7 +646,7 @@ class FP16Linear(nn.Linear):
             trans_weight = trans_weight.data.reshape(m//16, 16, n)
             trans_weight = trans_weight.transpose(1, 2).contiguous()
             self.weight.data = trans_weight
-        self.weight_type = 3
+            self.weight_type = 3
 
 
 class BF16Linear(nn.Linear):
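
Reviewer note (not part of the patch): for in_len == 4096 the weight is kept in a blocked (out_len//16, in_len, 16) layout for the esimd kernel (weight_type == 3), and the new x_2d.shape[0] > 8 branch undoes that blocking before falling back to F.linear. The standalone sketch below, with stand-in sizes m = 32, n = 4096 (not taken from the patch), checks that the reshape/transpose round-trip used in the fallback recovers the original weight; float32 is used here for portability even though FP16Linear holds fp16 data.

import torch
import torch.nn.functional as F

m, n = 32, 4096  # stand-ins for out_len, in_len (4096 selects the weight_type == 3 layout)
weight = torch.randn(m, n)

# Conversion mirroring convert_weight_for_esimd_kernel:
# (m, n) -> (m//16, 16, n) -> (m//16, n, 16)
esimd_weight = weight.reshape(m // 16, 16, n).transpose(1, 2).contiguous()

# Re-conversion mirroring the batch > 8 fallback in forward():
# (m//16, n, 16) -> (m//16, 16, n) -> (m, n)
restored = esimd_weight.transpose(1, 2).reshape(m, n)
assert torch.equal(restored, weight)  # blocking is fully reversible

x = torch.randn(9, n)  # batch size > 8 takes the F.linear path
out = F.linear(x, restored.contiguous())
print(out.shape)  # torch.Size([9, 32])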