From 13d47955a810067fd04db95c713b4fb4c5aae121 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Thu, 7 Dec 2023 09:21:41 +0800
Subject: [PATCH] use fused rms norm in chatglm2 and baichuan (#9613)

* use fused rms norm in chatglm2 and baichuan

* style fix
---
 .../llm/transformers/models/baichuan2.py      | 17 ++++++++---------
 .../bigdl/llm/transformers/models/chatglm2.py | 17 ++++++++---------
 .../bigdl/llm/transformers/models/llama.py    | 19 +++++++------------
 3 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 8b42cb3f..bf1add5f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -47,28 +47,27 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
 def baichuan_13b_rms_norm_forward(self, hidden_states):
-    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            if self.epsilon == 1e-6:
-                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                                 [self.weight.size(0)],
-                                                                 self.weight)
-                optimized_rms_norm = True
+            import linear_q4_0
+            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+                                                       [self.weight.size(0)],
+                                                       self.weight,
+                                                       None,
+                                                       self.epsilon)
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.epsilon)
-            optimized_rms_norm = True
-    if not optimized_rms_norm:
+        return hidden_states
+    else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
         return self.weight * hidden_states.to(input_dtype)
-    return hidden_states
 
 
 def baichuan_attention_forward_7b(
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 41df1038..d4c3e326 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -77,14 +77,14 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
 
 
 def chatglm_rms_norm_forward(self, hidden_states):
-    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            if self.eps == 1e-6:
-                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                                 [self.weight.size(0)],
-                                                                 self.weight)
-                optimized_rms_norm = True
+            import linear_q4_0
+            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+                                                       [self.weight.size(0)],
+                                                       self.weight,
+                                                       None,
+                                                       self.eps)
         else:
             # for ipex >= 2.1
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
@@ -92,14 +92,13 @@ def chatglm_rms_norm_forward(self, hidden_states):
                                                                self.weight,
                                                                None,  # bias
                                                                self.eps)
-            optimized_rms_norm = True
-    if not optimized_rms_norm:
+        return hidden_states
+    else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         return self.weight * hidden_states.to(input_dtype)
-    return hidden_states
 
 
 def chatglm2_model_forward(
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 7e8a77be..024de8cd 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -76,30 +76,25 @@ def get_ipex_version():
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            if self.variance_epsilon == 1e-6:
-                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                                 [self.weight.size(0)],
-                                                                 self.weight)
-            else:
-                import linear_q4_0
-                hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                           [self.weight.size(0)],
-                                                           self.weight,
-                                                           None,
-                                                           self.variance_epsilon)
+            import linear_q4_0
+            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+                                                       [self.weight.size(0)],
+                                                       self.weight,
+                                                       None,
+                                                       self.variance_epsilon)
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.variance_epsilon)
+        return hidden_states
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
-    return hidden_states
 
 
 def llama_attention_forward_4_31(
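
Editor's note (not part of the patch): all three rewritten forwards share the same eager CPU/autograd fallback, which doubles as a reference for verifying the fused XPU kernels. Below is a minimal standalone sketch of that fallback math (square/mean/rsqrt in fp32, cast back, then rescale by the learned weight). The function name, tensor shapes, and eps default here are illustrative assumptions, not taken from the patch.

    # Plain-PyTorch RMSNorm matching the fallback branch in this patch.
    # Names, shapes, and eps are illustrative, not from the patch itself.
    import torch

    def rms_norm_reference(hidden_states: torch.Tensor,
                           weight: torch.Tensor,
                           eps: float = 1e-6) -> torch.Tensor:
        # Compute in fp32 for numerical stability, as the fallback does.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        # Cast back to the original dtype, then apply the learned scale.
        return weight * hidden_states.to(input_dtype)

    # Usage: compare this output against the fused path on an XPU device.
    x = torch.randn(1, 8, 4096, dtype=torch.float16)
    w = torch.ones(4096, dtype=torch.float16)
    out = rms_norm_reference(x, w)

Note the restructured control flow: the XPU branch now returns immediately after the fused or fast kernel, so the flag variable (optimized_rms_norm) and the trailing return are no longer needed.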