optimize llama 3.2 rope (#12128)
This commit is contained in:
parent 584c3489e7
commit a266528719
1 changed file with 16 additions and 4 deletions
@@ -48,6 +48,7 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 
@@ -111,6 +112,12 @@ def llama_model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
+    # IPEX-LLM OPT start: use fused rope
+    if (should_use_fuse_rope(hidden_states, position_ids, False)
+            and self.rotary_emb.rope_type == "llama3"):
+        position_embeddings = self.rotary_emb.inv_freq
+    # IPEX-LLM OPT end
+
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
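A minimal sketch of the dispatch this hunk sets up, for context: when the fused-rope condition holds, position_embeddings carries the rotary inv_freq tensor instead of the usual (cos, sin) tuple, and the attention code later tells the two cases apart by type. The ToyRotaryEmb class and the fuse_rope_ok flag below are assumptions for illustration; only the inv_freq-versus-tuple convention comes from the diff.

import torch

class ToyRotaryEmb:
    # Toy stand-in for the model's rotary embedding module (assumption for illustration);
    # real Llama 3.x RoPE applies extra "llama3" frequency scaling not reproduced here.
    rope_type = "llama3"

    def __init__(self, head_dim=64, base=500000.0):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

    def __call__(self, x, position_ids):
        freqs = position_ids[..., None].float() * self.inv_freq   # (bs, seq, head_dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)                   # (bs, seq, head_dim)
        return emb.cos(), emb.sin()

rotary_emb = ToyRotaryEmb()
hidden_states = torch.randn(1, 8, 2048)
position_ids = torch.arange(8).unsqueeze(0)

fuse_rope_ok = True   # stand-in for should_use_fuse_rope(hidden_states, position_ids, False)
if fuse_rope_ok and rotary_emb.rope_type == "llama3":
    position_embeddings = rotary_emb.inv_freq                     # single tensor -> fused kernel path
else:
    position_embeddings = rotary_emb(hidden_states, position_ids) # (cos, sin) tuple -> default path

# llama_attention_forward later branches on the type of position_embeddings:
print(isinstance(position_embeddings, torch.Tensor))              # True -> xe_addons fused rope

Reusing the existing position_embeddings variable keeps the call into the decoder layers unchanged; only its payload differs between the two paths.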
@@ -179,6 +186,11 @@ def llama_attention_forward(
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
+
+    if isinstance(position_embeddings, torch.Tensor):
+        import xe_addons
+        inv_freq = position_embeddings
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids, query_states, key_states)
+    else:
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
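As a rough reference for what the fused call computes, here is an eager-mode sketch of rotate-half RoPE built from inv_freq and position_ids and applied to the query/key states, following the apply_rotary_pos_emb convention imported at the top of the file. Treating this as equivalent to xe_addons.rotary_half_inplaced (which works in place on the XPU tensors) is an assumption based on the kernel's arguments, not a documented contract.

import torch

def rotate_half(x):
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(inv_freq, position_ids, query_states, key_states):
    # Build cos/sin from inv_freq and the positions, then apply rotate-half RoPE
    # to query/key states of shape (bs, n_heads, seq_len, head_dim).
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # (bs, seq, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                             # (bs, seq, head_dim)
    cos = emb.cos()[:, None, :, :]                                      # broadcast over heads
    sin = emb.sin()[:, None, :, :]
    q = query_states * cos + rotate_half(query_states) * sin
    k = key_states * cos + rotate_half(key_states) * sin
    return q, k

bs, n_heads, n_kv_heads, seq_len, head_dim = 1, 8, 2, 16, 64
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(seq_len).unsqueeze(0)
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_kv_heads, seq_len, head_dim)
q_rot, k_rot = rope_reference(inv_freq, position_ids, q, k)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 8, 16, 64]) torch.Size([1, 2, 16, 64])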