[LLM] Reimplement chatglm fuse rms optimization (#9260)

* re-implement chatglm rope rms

* update
This commit is contained in:
SONG Ge 2023-10-24 16:35:12 +08:00 committed by GitHub
parent 5a2ce421af
commit bd5215d75b
2 changed files with 15 additions and 1 deletions

View file

@ -226,6 +226,7 @@ def _optimize_post(model):
module = importlib.import_module(modeling_module_name) module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
convert_forward(model, convert_forward(model,
module.SelfAttention, module.SelfAttention,
chatglm2_attention_forward_8eb45c chatglm2_attention_forward_8eb45c
@ -235,7 +236,7 @@ def _optimize_post(model):
core_attn_forward_8eb45c) core_attn_forward_8eb45c)
convert_forward(model, convert_forward(model,
module.RMSNorm, module.RMSNorm,
llama_rms_norm_forward) chatglm_rms_norm_forward)
elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528: elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
# chatglm-6b # chatglm-6b
modeling_module_name = model.__class__.__module__ modeling_module_name = model.__class__.__module__

View file

@ -74,6 +74,19 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
return torch.cat((x_out2, x_pass), dim=-1) return torch.cat((x_out2, x_pass), dim=-1)
def chatglm_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)], self.weight)
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states.to(input_dtype)
return hidden_states
def chatglm2_attention_forward_8eb45c( def chatglm2_attention_forward_8eb45c(
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
): ):