[LLM] Reimplement chatglm fuse rms optimization (#9260)
* re-implement chatglm rope rms
* update
parent 5a2ce421af
commit bd5215d75b
2 changed files with 15 additions and 1 deletion
@@ -226,6 +226,7 @@ def _optimize_post(model):
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
             from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             convert_forward(model,
                             module.SelfAttention,
                             chatglm2_attention_forward_8eb45c
@@ -235,7 +236,7 @@ def _optimize_post(model):
                             core_attn_forward_8eb45c)
             convert_forward(model,
                             module.RMSNorm,
-                            llama_rms_norm_forward)
+                            chatglm_rms_norm_forward)
         elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
             # chatglm-6b
             modeling_module_name = model.__class__.__module__
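For context, convert_forward is the patching helper the hunk above uses to route module.RMSNorm through chatglm_rms_norm_forward. A minimal sketch of how such a helper can work (illustrative only, not the library's actual implementation; patch_forward is a hypothetical name) is to walk the module tree and rebind forward on every instance of the target class:

import types

import torch


def patch_forward(model: torch.nn.Module, target_cls, new_forward) -> None:
    """Sketch of a convert_forward-style helper (illustrative, not bigdl's code)."""
    for module in model.modules():
        if isinstance(module, target_cls):
            # Rebind so new_forward receives this module instance as `self`.
            module.forward = types.MethodType(new_forward, module)

Once patched, every RMSNorm call in the ChatGLM2 model dispatches to the new forward, which is what lets the fused IPEX kernel take over on XPU devices.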
@@ -74,6 +74,19 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
     return torch.cat((x_out2, x_pass), dim=-1)
 
 
+def chatglm_rms_norm_forward(self, hidden_states):
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                         [self.weight.size(0)], self.weight)
+    else:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        return self.weight * hidden_states.to(input_dtype)
+    return hidden_states
+
+
 def chatglm2_attention_forward_8eb45c(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
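As a quick sanity check of the new forward on a machine without an Intel XPU (where the fused torch_ops.torch_ipex.rms_norm kernel is not used and the eager else branch runs), the function can be exercised against a reference RMS norm. SimpleRMSNorm below is an illustrative stand-in for ChatGLM2's RMSNorm module, not the model's actual class, and the snippet assumes bigdl-llm is installed so the new function can be imported:

import torch

from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward


class SimpleRMSNorm(torch.nn.Module):
    """Illustrative stand-in for ChatGLM2's RMSNorm: y = w * x / sqrt(mean(x^2) + eps)."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)


norm = SimpleRMSNorm(4096)
x = torch.randn(1, 8, 4096, dtype=torch.float16)
# On CPU the patched forward takes the eager branch, so both paths compute the same math.
reference = norm(x)
patched = chatglm_rms_norm_forward(norm, x)  # plain function call, `norm` passed as self
assert torch.equal(reference, patched)

On an XPU device with IPEX available, the same call would instead hit the fused kernel, which is the point of the optimization in this commit.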