From c0ef70df02f1ed6be1241587bc3e76842e92cf2a Mon Sep 17 00:00:00 2001
From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com>
Date: Thu, 16 Nov 2023 14:42:16 +0800
Subject: [PATCH] llm: quick fix of fast_rms_norm (#9480)

---
 .../llm/src/bigdl/llm/transformers/models/baichuan2.py | 10 +++++-----
 python/llm/src/bigdl/llm/transformers/models/llama.py  | 10 +++++-----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 7827c1fb..cf3a0d2c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -53,11 +53,11 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
                                                              [self.weight.size(0)],
                                                              self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                                  [self.weight.size(0)],
-                                                                  self.weight,
-                                                                  None,
-                                                                  self.epsilon)
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,
+                                                               self.epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 64a6dd33..a5f7e021 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -79,11 +79,11 @@ def llama_rms_norm_forward(self, hidden_states):
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                                  [self.weight.size(0)],
-                                                                  self.weight,
-                                                                  None,
-                                                                  self.variance_epsilon)
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,
+                                                               self.variance_epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
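
Note (not part of the patch): the fix drops the tuple unpacking on the fast_rms_norm call, presumably because torch.ops.torch_ipex.fast_rms_norm returns a bare tensor rather than a (tensor, ...) pair in the targeted IPEX build, whereas torch.ops.torch_ipex.rms_norm still returns a pair. Below is a minimal sketch of a defensive wrapper that tolerates either return shape. It assumes an XPU build of intel_extension_for_pytorch is installed so the torch_ipex ops are registered; the helper name fast_rms_norm_compat is hypothetical and only illustrates the idea.

    import torch
    import intel_extension_for_pytorch  # noqa: F401  (registers torch.ops.torch_ipex.*)

    def fast_rms_norm_compat(hidden_states, weight, eps):
        # Same argument layout as the patched forward functions:
        # (input, normalized_shape, weight, bias, epsilon).
        out = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                  [weight.size(0)],
                                                  weight,
                                                  None,
                                                  eps)
        # Some op builds return (output, ...) while others return the output
        # tensor directly; unwrap the tuple case so callers always get a tensor.
        if isinstance(out, tuple):
            out = out[0]
        return out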