[LLM] Add rope optimization for internlm (#9159)

* add rope and norm optimization for internlm and gptneox

* revert the gptneox changes and split them out into PR #9155

* add norm_forward

* style fix

* update

* update
commit e7aa67e141
parent f754ab3e60
Author: SONG Ge
Date:   2023-10-13 14:18:28 +08:00 (committed by GitHub)

2 changed files with 18 additions and 9 deletions

@@ -311,6 +311,10 @@ def optimize(model):
                         module.InternLMAttention,
                         internlm_attention_forward
                         )
+        convert_forward(model,
+                        module.InternLMRMSNorm,
+                        llama_rms_norm_forward
+                        )
     elif model.config.model_type == "qwen":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)

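Note on the first hunk: the added convert_forward call routes InternLMRMSNorm through llama_rms_norm_forward, which is possible because InternLM's norm computes the same RMSNorm operation as LLaMA's. The sketch below is only a plain-PyTorch reference for that operation; it is not the repository's llama_rms_norm_forward, which may dispatch to an optimized (e.g. fused XPU) kernel.

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Illustrative sketch only, not the repo's llama_rms_norm_forward.
    # Accumulate the mean of squares in fp32 for numerical stability.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # Scale by the learned per-channel weight and restore the input dtype.
    return weight * hidden_states.to(input_dtype)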

@@ -74,15 +74,20 @@ def internlm_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states,
-        key_states,
-        cos,
-        sin,
-        position_ids,
-        "internlm"
-    )
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "internlm")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states,
+            key_states,
+            cos,
+            sin,
+            position_ids,
+            "internlm")
     # [bsz, nh, t, hd]
     if past_key_value is not None:
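
Note on the second hunk: the new guard takes the apply_rotary_pos_emb_no_cache_xpu path only for inference on XPU devices (when training with gradients, the original cached cos/sin path is kept). The sketch below is a reference for the rotary-embedding math both branches produce, written in the style of the stock Hugging Face helpers; the no-cache XPU variant is assumed to fuse the cos/sin generation and rotation into one kernel instead of reading a precomputed cache.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Map (x1, x2) -> (-x2, x1) over the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_reference(q, k, cos, sin, position_ids):
    # Illustrative sketch only, not the repository's apply_rotary_pos_emb
    # or apply_rotary_pos_emb_no_cache_xpu.
    # cos/sin arrive as [1, 1, seq_len, head_dim]; select the rows for each position.
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bsz, 1, q_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed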