LLM: Add fuse rope and norm optimization for Aquila. (#9161)

* Add fused RMSNorm optimization.

* Add fused RoPE optimization.
Cengguang Zhang authored on 2023-10-13 14:18:37 +08:00; committed by GitHub
parent e7aa67e141
commit 433f408081
2 changed files with 13 additions and 3 deletions
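Context for the diffs below: bigdl-llm applies these optimizations by rebinding the forward method of matching modules with its convert_forward helper. As a rough illustration only (not the library's exact code, and convert_forward_sketch is a hypothetical name), the rebinding pattern looks like this:

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward) -> None:
    # Walk the module tree and rebind `forward` on every instance of target_cls,
    # so calls like module(hidden_states) dispatch to the optimized implementation.
    # Hypothetical helper for illustration; bigdl-llm ships its own convert_forward.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)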


@@ -331,4 +331,7 @@ def optimize(model):
                         module.AquilaAttention,
                         aquila_attention_forward
                         )
+        convert_forward(model,
+                        module.AquilaRMSNorm,
+                        llama_rms_norm_forward)
     return model
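The new convert_forward call above swaps AquilaRMSNorm's forward for llama_rms_norm_forward, which runs the normalization as a fused kernel. Numerically it should match standard RMSNorm; a minimal reference sketch of that math (assuming the usual Hugging Face module attributes weight and variance_epsilon, and not the fused implementation itself):

import torch

def rms_norm_forward_reference(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Reference RMSNorm semantics: normalize by the root-mean-square of the last
    # dimension, then apply the learned weight. The fused llama_rms_norm_forward
    # is expected to produce the same result in a single kernel.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)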


@@ -44,6 +44,7 @@ from torch import nn
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.utils.common import log4Error

 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -73,6 +74,12 @@ def aquila_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids, "aquila")
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "aquila")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "aquila")
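The else branch keeps the original two-step path (build cos/sin from the rotary embedding cache, then rotate), while apply_rotary_pos_emb_no_cache_xpu fuses those steps into one XPU kernel for inference. A reference sketch of the half-rotation ("aquila"/LLaMA-style) RoPE math it replaces, with hypothetical function names and a rotary base of 10000 assumed rather than taken from the library:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap and negate the two halves of the last dimension.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_reference(q, k, position_ids, head_dim, base=10000.0):
    # Build cos/sin on the fly (no cache), then rotate q and k of shape
    # [batch, num_heads, seq_len, head_dim]; position_ids is [batch, seq_len].
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32,
                                            device=q.device) / head_dim))
    freqs = position_ids.to(torch.float32)[..., None] * inv_freq  # [bs, seq, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                       # [bs, seq, head_dim]
    cos = emb.cos()[:, None, :, :]                                # broadcast over heads
    sin = emb.sin()[:, None, :, :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(q.dtype), k_embed.to(k.dtype)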