optimize llama 3.2 rope (#12128)

Author: Yishuo Wang
Date:   2024-09-26 16:08:10 +08:00 (committed via GitHub)
parent 584c3489e7
commit a266528719

@@ -48,6 +48,7 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
@@ -111,6 +112,12 @@ def llama_model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
+    # IPEX-LLM OPT start: use fused rope
+    if (should_use_fuse_rope(hidden_states, position_ids, False)
+            and self.rotary_emb.rope_type == "llama3"):
+        position_embeddings = self.rotary_emb.inv_freq
+    # IPEX-LLM OPT end
+
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
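
The gate above only takes the fused path for checkpoints whose rope type is "llama3", where the Llama 3.x frequency scaling is already folded into `self.rotary_emb.inv_freq` at init time, so passing the raw `inv_freq` tensor through `position_embeddings` gives each attention layer everything the fused kernel needs. As a minimal sketch of the equivalence, assuming HF `LlamaRotaryEmbedding` shape conventions (the helper name here is hypothetical, not ipex-llm API), the unfused path would derive its cos/sin tables from that same tensor:

```python
import torch

def cos_sin_from_inv_freq(inv_freq: torch.Tensor, position_ids: torch.Tensor):
    # inv_freq: [head_dim // 2]; position_ids: [batch, seq_len]
    # The stock HF path materializes these tables once per forward;
    # the fused kernel can recompute them on the fly from inv_freq.
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((freqs, freqs), dim=-1)  # [batch, seq_len, head_dim]
    return emb.cos(), emb.sin()
```
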
@@ -179,11 +186,16 @@ def llama_attention_forward(
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
 
-    if position_embeddings is None:
-        cos, sin = self.rotary_emb(value_states, position_ids)
+    if isinstance(position_embeddings, torch.Tensor):
+        import xe_addons
+        inv_freq = position_embeddings
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids, query_states, key_states)
     else:
-        cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
         key_states, value_states = past_key_value.update(key_states, value_states,
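
`xe_addons.rotary_half_inplaced` is a native XPU kernel shipped with ipex-llm, so its body is not visible in this diff. As a rough pure-PyTorch reference, a sketch of the rotate-half RoPE it applies in place to `query_states` and `key_states` could look like the following (my assumptions: HF's [batch, heads, seq_len, head_dim] layout and an attention scaling of 1.0, which holds for the "llama3" rope type):

```python
import torch

def rotary_half_inplaced_reference(inv_freq, position_ids, q, k):
    # q, k: [batch, num_heads, seq_len, head_dim], modified in place.
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((freqs, freqs), dim=-1)   # [batch, seq_len, head_dim]
    cos = emb.cos().unsqueeze(1).to(q.dtype)  # broadcast over heads
    sin = emb.sin().unsqueeze(1).to(q.dtype)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q.copy_(q * cos + rotate_half(q) * sin)
    k.copy_(k * cos + rotate_half(k) * sin)
```

Fusing this into a single in-place kernel skips materializing the cos/sin tables and the rotate-half temporaries at every decoder layer, which appears to be the point of routing `inv_freq` straight through `position_embeddings` in this commit.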