[LLM] Add rope optimization for internlm (#9159)

* add rope and norm optimization for internlm and gptneox

* revert the gptneox changes and split them out into PR #9155

* add norm_forward

* style fix

* update

* update
SONG Ge authored on 2023-10-13 14:18:28 +08:00; committed by GitHub
parent f754ab3e60
commit e7aa67e141
2 changed files with 18 additions and 9 deletions

@@ -311,6 +311,10 @@ def optimize(model):
                         module.InternLMAttention,
                         internlm_attention_forward
                         )
+        convert_forward(model,
+                        module.InternLMRMSNorm,
+                        llama_rms_norm_forward
+                        )
     elif model.config.model_type == "qwen":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)

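Note on the first hunk: convert_forward swaps the forward method of every module of the given class, so InternLM's RMSNorm layers are routed through the LLaMA norm path. As a rough illustration, an RMSNorm forward in the style of llama_rms_norm_forward computes the following (sketch only; the actual bigdl-llm implementation may instead dispatch to a fused XPU kernel, and the variance_epsilon attribute name is assumed from the Hugging Face modules):

    import torch

    def rms_norm_forward_sketch(self, hidden_states):
        # Illustrative RMSNorm, LLaMA-style; not the exact bigdl-llm code.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        # Normalize by the root mean square over the hidden dimension.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Scale by the learned weight and cast back to the original dtype.
        return self.weight * hidden_states.to(input_dtype)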

@@ -74,15 +74,20 @@ def internlm_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states,
-        key_states,
-        cos,
-        sin,
-        position_ids,
-        "internlm"
-    )
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "internlm")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states,
+            key_states,
+            cos,
+            sin,
+            position_ids,
+            "internlm")
     # [bsz, nh, t, hd]
     if past_key_value is not None:
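Note on the second hunk: on XPU devices at inference time, the new branch calls apply_rotary_pos_emb_no_cache_xpu, which applies the rotary embedding without going through the cached cos/sin tables from self.rotary_emb; on other devices, or when training with gradients enabled, the original path is kept. For reference, the rotation both paths compute is the standard RoPE transform, sketched below in the style of the Hugging Face helpers (names and cos/sin shape handling are assumptions; the fused XPU kernel in bigdl-llm differs internally):

    import torch

    def rotate_half(x):
        # Standard RoPE helper: split the head dim in half, swap and negate.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_reference(q, k, cos, sin, position_ids):
        # Reference (non-fused) rotation. Assumes cos/sin are [seq_len, head_dim]
        # tables indexed by position_ids; exact shapes vary across versions.
        cos = cos[position_ids].unsqueeze(1)  # -> [bsz, 1, seq_len, head_dim]
        sin = sin[position_ids].unsqueeze(1)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed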