diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 5652b02f..a8cd05df 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -226,6 +226,7 @@ def _optimize_post(model):
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
             from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             convert_forward(model,
                             module.SelfAttention,
                             chatglm2_attention_forward_8eb45c
@@ -235,7 +236,7 @@ def _optimize_post(model):
                             core_attn_forward_8eb45c)
             convert_forward(model,
                             module.RMSNorm,
-                            llama_rms_norm_forward)
+                            chatglm_rms_norm_forward)
         elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
             # chatglm-6b
             modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 7dc90f86..fa54ea3e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -74,6 +74,19 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     return torch.cat((x_out2, x_pass), dim=-1)


+def chatglm_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def chatglm2_attention_forward_8eb45c(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
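
A minimal sketch (not part of the patch) that exercises the eager, non-XPU branch of the new chatglm_rms_norm_forward and compares it against a naive float32 RMSNorm reference. DummyRMSNorm below is a hypothetical stand-in for ChatGLM2's RMSNorm module; the patched forward only relies on its weight and eps attributes. It assumes a bigdl-llm build that includes this change is installed; on an XPU tensor the function would instead dispatch to the fused torch.ops.torch_ipex.rms_norm kernel shown in the diff.

    import torch
    from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward


    class DummyRMSNorm(torch.nn.Module):
        # Hypothetical stand-in: only `weight` and `eps` are needed by the patched forward.
        def __init__(self, hidden_size, eps=1e-5):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(hidden_size))
            self.eps = eps


    norm = DummyRMSNorm(32)
    x = torch.randn(2, 5, 32)

    # Call the patched forward directly; on CPU this takes the pure-PyTorch branch.
    out = chatglm_rms_norm_forward(norm, x)

    # Naive reference: normalize by the per-token RMS in float32, then apply the weight.
    ref = norm.weight * (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + norm.eps))
    print(torch.allclose(out, ref, atol=1e-6))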