diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 7827c1fb..cf3a0d2c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -53,11 +53,11 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                                  [self.weight.size(0)],
-                                                                  self.weight,
-                                                                  None,
-                                                                  self.epsilon)
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,
+                                                               self.epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 64a6dd33..a5f7e021 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -79,11 +79,11 @@ def llama_rms_norm_forward(self, hidden_states):
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                                  [self.weight.size(0)],
-                                                                  self.weight,
-                                                                  None,
-                                                                  self.variance_epsilon)
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,
+                                                               self.variance_epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
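
Note: both hunks stop unpacking the return value of torch.ops.torch_ipex.fast_rms_norm, which suggests that on the newer IPEX versions this branch targets, the op returns the normalized tensor directly rather than a tuple. As a rough illustration of the pattern (not part of the patch; the helper name and the tuple fallback are assumptions), the sketch below calls the op with the same arguments used in these files and tolerates either return convention.

# Sketch only: a small compatibility helper around the fused RMSNorm op used above.
# Requires intel_extension_for_pytorch so that torch.ops.torch_ipex is registered;
# the helper name and the isinstance fallback are illustrative, not BigDL code.
import torch
import intel_extension_for_pytorch  # noqa: F401  (registers torch.ops.torch_ipex)

def fused_rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    out = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                             [weight.size(0)],
                                             weight,
                                             None,
                                             eps)
    # Older code unpacked a tuple here (hence the removed `hidden_states, _ = ...`);
    # the updated code assigns the result directly. Accept both shapes.
    if isinstance(out, tuple):
        out = out[0]
    return out

Used on XPU tensors, this mirrors the guarded else-branch in llama_rms_norm_forward and baichuan_13b_rms_norm_forward without committing to a specific IPEX version check.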