refactor to remove old rope usage (#12224)
parent 324bcb057e
commit 9ea694484d

5 changed files with 34 additions and 56 deletions
@@ -25,7 +25,7 @@
 # https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
 # ===========================================================================
 #
 # The patching on this file refers to https://huggingface.co/tiiuae/falcon-7b/discussions/17
@@ -295,17 +295,15 @@ class Attention(nn.Module):
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]

-        use_fuse_rope = query_layer.device.type == "xpu"
-        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
-        if use_fuse_rope:
-            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+        from ipex_llm.transformers.models.utils import should_use_fuse_rope
+        if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
+                isinstance(self.maybe_rotary, RotaryEmbedding):
+            # resize qk to 4D to match rotary_half_inplaced's requirements.
             query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
-            from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
-                                                                       key_layer,
-                                                                       position_ids,
-                                                                       "gpt_neox")
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                           query_layer, key_layer)
             query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
         else:
@@ -25,7 +25,7 @@
 # https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
 # ===========================================================================
 #
 # The patching on this file refers to https://huggingface.co/tiiuae/falcon-7b/discussions/17

@@ -295,17 +295,15 @@ class Attention(nn.Module):
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]

-        use_fuse_rope = query_layer.device.type == "xpu"
-        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
-        if use_fuse_rope:
-            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+        from ipex_llm.transformers.models.utils import should_use_fuse_rope
+        if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
+                isinstance(self.maybe_rotary, RotaryEmbedding):
+            # resize qk to 4D to match rotary_half_inplaced's requirements.
             query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
-            from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
-                                                                       key_layer,
-                                                                       position_ids,
-                                                                       "gpt_neox")
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                           query_layer, key_layer)
             query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
         else:
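The + lines above hand RoPE off to xe_addons.rotary_half_inplaced, a native XPU kernel that rotates query_layer/key_layer in place. As a rough reference for what that call computes, here is a pure-PyTorch sketch of the standard "rotate half" RoPE on the [batch, num_heads, seq_len, head_dim] layout that the reshape lines above set up; the function name and the assumption that the kernel follows this exact formulation are mine, not part of the commit.

import torch

def rotary_half_reference(inv_freq, position_ids, q, k):
    # Pure-PyTorch stand-in for the fused kernel: apply "rotate half" RoPE to
    # q and k in place. Assumed shapes: q, k = [batch, heads, seq, head_dim],
    # position_ids = [batch, seq], inv_freq = [head_dim // 2].
    freqs = position_ids.unsqueeze(-1).float() * inv_freq.view(1, 1, -1)
    emb = torch.cat((freqs, freqs), dim=-1)      # [batch, seq, head_dim]
    cos = emb.cos().unsqueeze(1)                 # broadcast over heads
    sin = emb.sin().unsqueeze(1)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # Inference-only, like the fused path: plain in-place writes, no autograd.
    q.copy_(q * cos + rotate_half(q) * sin)
    k.copy_(k * cos + rotate_half(k) * sin)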
@@ -35,7 +35,6 @@ import torch
 from typing import Optional, Tuple
 import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import update_past_key_value
@@ -77,10 +76,9 @@ def decilm_attention_forward_4_35_2(
         kv_seq_len += past_key_value[0].shape[-2]

     if should_use_fuse_rope(hidden_states, position_ids, self.training):
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama")
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
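The gating that used to be spelled out inline (device is "xpu", not a training pass, as in the lines the falcon hunks above delete) is now centralised in ipex_llm.transformers.models.utils.should_use_fuse_rope. Below is a rough approximation of that gate, reconstructed from the deleted inline checks; the real helper may apply more or different conditions.

def should_use_fuse_rope_sketch(hidden_states, position_ids, training):
    # Approximation of the gate, rebuilt from the inline checks this commit
    # removes: the fused kernel is XPU-only and is skipped while training.
    # The position_ids check is an extra assumption, not taken from the diff.
    return (hidden_states.device.type == "xpu"
            and not (training and hidden_states.requires_grad)
            and position_ids is not None)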
@@ -33,10 +33,10 @@

 import torch
 from typing import Optional, Tuple
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu

 import os
@@ -44,14 +44,14 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH",

 def gptneox_attention_forward(
     self,
     hidden_states: torch.FloatTensor,
     attention_mask: torch.FloatTensor,
     position_ids: torch.LongTensor,
     head_mask: Optional[torch.FloatTensor] = None,
     layer_past: Optional[Tuple[torch.Tensor]] = None,
     use_cache: Optional[bool] = False,
     output_attentions: Optional[bool] = False,
 ):
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
@@ -89,11 +89,12 @@ def gptneox_attention_forward(
     use_fuse_rope = query.device.type == "xpu"
     use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)

-    if use_fuse_rope:
-        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
-                                                       key_rot,
-                                                       position_ids,
-                                                       "gpt_neox")
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_rot, key_rot)
+        query = query_rot
+        key = key_rot
     else:
         cos, sin = self.rotary_emb(value, seq_len=seq_len)
         query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")
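One behavioural detail: the old apply_rotary_pos_emb_no_cache_xpu helper (deleted from utils in the last hunk below) allocated and returned fresh q_embed/k_embed tensors, whereas rotary_half_inplaced writes the rotated values back into the tensors it is given, which is why this hunk simply rebinds query = query_rot and key = key_rot after the call. A caller that wants a non-XPU fallback could dispatch roughly as sketched here; the dispatcher is illustrative only, and rotary_half_reference refers to the sketch shown earlier, not to ipex_llm API.

def apply_rope_inplace(inv_freq, position_ids, q, k, training=False):
    # Hypothetical dispatcher (not part of the commit): take the fused XPU
    # kernel when the gate passes, otherwise fall back to the pure-PyTorch
    # reference. Both branches rotate q and k in place.
    if q.device.type == "xpu" and not training:
        import xe_addons  # native XPU extension used by ipex-llm
        xe_addons.rotary_half_inplaced(inv_freq, position_ids, q, k)
    else:
        rotary_half_reference(inv_freq, position_ids, q, k)
    return q, k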
@@ -207,23 +207,6 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
     torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)


-def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
-    if q.device.type != "xpu":
-        invalidInputError(False,
-                          f"only xpu is supported in this function")
-    import xe_addons
-    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
-    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
-    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral"]:
-        xe_addons.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
-                                                      q_embed, k_embed, rope_theta)
-        return q_embed, k_embed
-    else:
-        invalidInputError(False,
-                          f"{model_family} is not supported.")
-
-
 def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
     if q.device.type != "xpu":
         invalidInputError(False,
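The deleted helper passed rope_theta down to xe_addons.apply_rotary_embedding_half_q_and_k and let the kernel derive the frequencies; the new call sites instead pass the inv_freq buffer that each model's rotary embedding module already owns. The two are related by the standard RoPE frequency table, sketched below; this is the textbook formula, not code from this commit.

import torch

def build_inv_freq(head_dim, rope_theta=10000.0):
    # Standard RoPE frequencies: one inverse frequency per pair of channels.
    # HF rotary embedding modules keep this as `inv_freq`, which is what the
    # new call sites hand to rotary_half_inplaced instead of rope_theta.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponents)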