refactor ot remove old rope usage (#12224)
This commit is contained in:
parent
324bcb057e
commit
9ea694484d
5 changed files with 34 additions and 56 deletions
|
|
@ -25,7 +25,7 @@
|
|||
# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
|
||||
|
||||
# ===========================================================================
|
||||
#
|
||||
#
|
||||
# The patching on this file refers to https://huggingface.co/tiiuae/falcon-7b/discussions/17
|
||||
|
||||
|
||||
|
|
@ -295,17 +295,15 @@ class Attention(nn.Module):
|
|||
# query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
|
||||
use_fuse_rope = query_layer.device.type == "xpu"
|
||||
use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
|
||||
if use_fuse_rope:
|
||||
# resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
|
||||
isinstance(self.maybe_rotary, RotaryEmbedding):
|
||||
# resize qk to 4D to match rotary_half_inplaced's requirements.
|
||||
query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
|
||||
key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||
query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
|
||||
key_layer,
|
||||
position_ids,
|
||||
"gpt_neox")
|
||||
import xe_addons
|
||||
xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
|
||||
query_layer, key_layer)
|
||||
query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
||||
key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
|
||||
|
||||
# ===========================================================================
|
||||
#
|
||||
#
|
||||
# The patching on this file refers to https://huggingface.co/tiiuae/falcon-7b/discussions/17
|
||||
|
||||
|
||||
|
|
@ -295,17 +295,15 @@ class Attention(nn.Module):
|
|||
# query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
|
||||
use_fuse_rope = query_layer.device.type == "xpu"
|
||||
use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
|
||||
if use_fuse_rope:
|
||||
# resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
|
||||
isinstance(self.maybe_rotary, RotaryEmbedding):
|
||||
# resize qk to 4D to match rotary_half_inplaced's requirements.
|
||||
query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
|
||||
key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||
query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
|
||||
key_layer,
|
||||
position_ids,
|
||||
"gpt_neox")
|
||||
import xe_addons
|
||||
xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
|
||||
query_layer, key_layer)
|
||||
query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
||||
key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ import torch
|
|||
from typing import Optional, Tuple
|
||||
import torch.nn.functional as F
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||
from ipex_llm.transformers.models.llama import repeat_kv
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import update_past_key_value
|
||||
|
|
@ -77,10 +76,9 @@ def decilm_attention_forward_4_35_2(
|
|||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||
key_states,
|
||||
position_ids,
|
||||
"llama")
|
||||
import xe_addons
|
||||
xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
|
||||
query_states, key_states)
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
|
|
|
|||
|
|
@ -33,10 +33,10 @@
|
|||
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
|
||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
||||
append_kv_cache, is_enough_kv_cache_room_4_31
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||
|
||||
import os
|
||||
|
||||
|
|
@ -44,14 +44,14 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH",
|
|||
|
||||
|
||||
def gptneox_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
|
|
@ -89,11 +89,12 @@ def gptneox_attention_forward(
|
|||
use_fuse_rope = query.device.type == "xpu"
|
||||
use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
|
||||
|
||||
if use_fuse_rope:
|
||||
query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
|
||||
key_rot,
|
||||
position_ids,
|
||||
"gpt_neox")
|
||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||
import xe_addons
|
||||
xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
|
||||
query_rot, key_rot)
|
||||
query = query_rot
|
||||
key = key_rot
|
||||
else:
|
||||
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
|
||||
|
|
|
|||
|
|
@ -207,23 +207,6 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
|
|||
torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
|
||||
if q.device.type != "xpu":
|
||||
invalidInputError(False,
|
||||
f"only xpu is supported in this function")
|
||||
import xe_addons
|
||||
q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
|
||||
k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
|
||||
if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
|
||||
"mixtral"]:
|
||||
xe_addons.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
|
||||
q_embed, k_embed, rope_theta)
|
||||
return q_embed, k_embed
|
||||
else:
|
||||
invalidInputError(False,
|
||||
f"{model_family} is not supported.")
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
|
||||
if q.device.type != "xpu":
|
||||
invalidInputError(False,
|
||||
|
|
|
|||
Loading…
Reference in a new issue