[LLM] Reimplement chatglm fuse rms optimization (#9260)
* re-implement chatglm rope rms * update
This commit is contained in:
parent
5a2ce421af
commit
bd5215d75b
2 changed files with 15 additions and 1 deletions
|
|
@ -226,6 +226,7 @@ def _optimize_post(model):
|
|||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
|
||||
from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
|
||||
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
|
||||
convert_forward(model,
|
||||
module.SelfAttention,
|
||||
chatglm2_attention_forward_8eb45c
|
||||
|
|
@ -235,7 +236,7 @@ def _optimize_post(model):
|
|||
core_attn_forward_8eb45c)
|
||||
convert_forward(model,
|
||||
module.RMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
chatglm_rms_norm_forward)
|
||||
elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
|
||||
# chatglm-6b
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
|
|||
|
|
@ -74,6 +74,19 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
|
|||
return torch.cat((x_out2, x_pass), dim=-1)
|
||||
|
||||
|
||||
def chatglm_rms_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
|
||||
[self.weight.size(0)], self.weight)
|
||||
else:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def chatglm2_attention_forward_8eb45c(
|
||||
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
||||
):
|
||||
|
|
|
|||
Loading…
Reference in a new issue