From 70ed9397f907a14763624c4f093dd1f64a39db76 Mon Sep 17 00:00:00 2001 From: binbin Deng <108676127+plusbang@users.noreply.github.com> Date: Thu, 11 Apr 2024 17:03:56 +0800 Subject: [PATCH] LLM: fix AttributeError of FP16Linear (#10740) --- python/llm/src/ipex_llm/transformers/low_bit_linear.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 0acca42a..c390dd90 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -729,7 +729,7 @@ class LowBitLinear(nn.Linear): class FP16Linear(nn.Linear): def __init__(self, input_features, output_features, bias=True, - mp_group=None, weight_type=1, + mp_group=None, weight_type=1, enable_xetla=False, optimize_lm_head=False): super().__init__(input_features, output_features, bias) self.in_len = input_features @@ -743,6 +743,7 @@ class FP16Linear(nn.Linear): # weigh_type = 3 means weight has been transposed by esimd method self.weight_type = 1 self.optimize_lm_head = optimize_lm_head + self.enable_xetla = enable_xetla def forward(self, x: torch.Tensor): # only work for GPU @@ -849,7 +850,7 @@ class FP16Linear(nn.Linear): class BF16Linear(nn.Linear): def __init__(self, input_features, output_features, bias=True, - mp_group=None, compute_dtype=None, + mp_group=None, compute_dtype=None, enable_xetla=False, optimize_lm_head=False): super().__init__(input_features, output_features, bias) self.in_len = input_features @@ -860,6 +861,7 @@ class BF16Linear(nn.Linear): self.mp_group = mp_group self.compute_dtype = compute_dtype self.optimize_lm_head = optimize_lm_head + self.enable_xetla = enable_xetla def forward(self, x: torch.Tensor): if self.optimize_lm_head: