optimize llama 3.2 rope (#12128)
parent 584c3489e7
commit a266528719
1 changed file with 16 additions and 4 deletions
@@ -48,6 +48,7 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 
@@ -111,6 +112,12 @@ def llama_model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
+    # IPEX-LLM OPT start: use fused rope
+    if (should_use_fuse_rope(hidden_states, position_ids, False)
+            and self.rotary_emb.rope_type == "llama3"):
+        position_embeddings = self.rotary_emb.inv_freq
+    # IPEX-LLM OPT end
+
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
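For readers without the surrounding file, a minimal sketch of what this hunk sets up: when the fused-rope path is usable and the rope type is "llama3", the model-level forward hands the raw inv_freq tensor to the decoder layers instead of the precomputed (cos, sin) pair. The names should_use_fuse_rope_sketch and prepare_position_embeddings below are illustrative, and the stand-in predicate is an assumption about the helper's behavior (Intel GPU, inference only), not the library's implementation.

import torch

def should_use_fuse_rope_sketch(hidden_states: torch.Tensor,
                                position_ids: torch.Tensor,
                                training: bool) -> bool:
    # Assumed behavior (not the library's code): the fused kernel is only
    # worthwhile on Intel GPUs ("xpu"), outside of training, and when explicit
    # position ids are available.
    return (hidden_states.device.type == "xpu"
            and not training
            and position_ids is not None)

def prepare_position_embeddings(rotary_emb, hidden_states, position_ids):
    # Mirrors the hunk above: the default path keeps the (cos, sin) tuple from
    # the rotary embedding module, while the optimized llama3 path passes the
    # bare inv_freq tensor down to the attention layers instead.
    position_embeddings = rotary_emb(hidden_states, position_ids)
    if (should_use_fuse_rope_sketch(hidden_states, position_ids, False)
            and rotary_emb.rope_type == "llama3"):
        position_embeddings = rotary_emb.inv_freq
    return position_embeddings

Downstream code can then tell the two cases apart with an isinstance check, which is exactly what the attention hunk below does.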
@@ -179,11 +186,16 @@ def llama_attention_forward(
                                           self.num_key_value_heads,
                                           self.num_key_value_heads], dim=1)
 
-    if position_embeddings is None:
-        cos, sin = self.rotary_emb(value_states, position_ids)
+    if isinstance(position_embeddings, torch.Tensor):
+        import xe_addons
+        inv_freq = position_embeddings
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids, query_states, key_states)
     else:
-        cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
         key_states, value_states = past_key_value.update(key_states, value_states,
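The tensor-valued branch hands inv_freq to xe_addons.rotary_half_inplaced, which applies rotate-half RoPE to the query and key states in place. As a point of reference, here is a plain-PyTorch sketch of that computation; it assumes the llama3 frequency scaling is already folded into inv_freq with an attention factor of 1.0, and the name rotary_half_reference is illustrative, not the kernel's API.

import torch

def rotary_half_reference(inv_freq, position_ids, query_states, key_states):
    # Pure-PyTorch sketch of rotate-half RoPE driven directly by inv_freq.
    # Assumption: the llama3 scaling is already baked into inv_freq and the
    # attention scaling factor is 1.0; the real fused kernel is XPU-specific.
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # [b, s, d/2]
    emb = torch.cat((freqs, freqs), dim=-1)                             # [b, s, d]
    cos = emb.cos()[:, None, :, :].to(query_states.dtype)               # [b, 1, s, d]
    sin = emb.sin()[:, None, :, :].to(query_states.dtype)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # In-place update of q/k, matching the *_inplaced naming of the fused op.
    query_states.copy_(query_states * cos + rotate_half(query_states) * sin)
    key_states.copy_(key_states * cos + rotate_half(key_states) * sin)

The restriction to rope_type == "llama3" in the model-level hunk presumably keeps this fast path to the case where no additional scaling of cos and sin is required, so the kernel can work from inv_freq and position_ids alone.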