refactor to remove old rope usage (#12224)

Yishuo Wang 2024-10-17 17:06:09 +08:00 committed by GitHub
parent 324bcb057e
commit 9ea694484d
5 changed files with 34 additions and 56 deletions

View file

@@ -25,7 +25,7 @@
# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
# ===========================================================================
#
#
# The patching on this file refers to https://huggingface.co/tiiuae/falcon-7b/discussions/17
@@ -295,17 +295,15 @@ class Attention(nn.Module):
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        use_fuse_rope = query_layer.device.type == "xpu"
-        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
-        if use_fuse_rope:
-            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+        from ipex_llm.transformers.models.utils import should_use_fuse_rope
+        if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
+                isinstance(self.maybe_rotary, RotaryEmbedding):
+            # resize qk to 4D to match rotary_half_inplaced's requirements.
             query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
-            from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
-                                                                        key_layer,
-                                                                        position_ids,
-                                                                        "gpt_neox")
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                           query_layer, key_layer)
             query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
         else:

View file

@@ -25,7 +25,7 @@
# https://github.com/huggingface/transformers/blob/v4.34.1/LICENSE
# ===========================================================================
#
#
# The patching on this file refers to https://huggingface.co/tiiuae/falcon-7b/discussions/17
@@ -295,17 +295,15 @@ class Attention(nn.Module):
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        use_fuse_rope = query_layer.device.type == "xpu"
-        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
-        if use_fuse_rope:
-            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+        from ipex_llm.transformers.models.utils import should_use_fuse_rope
+        if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
+                isinstance(self.maybe_rotary, RotaryEmbedding):
+            # resize qk to 4D to match rotary_half_inplaced's requirements.
             query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
-            from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
-                                                                        key_layer,
-                                                                        position_ids,
-                                                                        "gpt_neox")
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                           query_layer, key_layer)
             query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
         else:

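The falcon changes above converge on the dispatch pattern this commit rolls out everywhere: gate on should_use_fuse_rope, run the fused in-place kernel on 4D query/key tensors on XPU, otherwise fall back to the cos/sin path. The sketch below illustrates that pattern in isolation; it assumes an ipex-llm XPU build where xe_addons is importable, and the wrapper function and its parameter names are illustrative rather than taken from any file in this commit.

# Illustrative sketch only: apply_rope_dispatch and its parameters are hypothetical;
# should_use_fuse_rope, apply_rotary_pos_emb and xe_addons are the ipex-llm pieces
# used in the diffs above.
import torch
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb


def apply_rope_dispatch(hidden_states, query, key, value, inv_freq, rotary_emb,
                        position_ids, kv_seq_len, training):
    # query/key: [batch, num_heads, seq_len, head_dim], the 4D layout that
    # xe_addons.rotary_half_inplaced rotates in place.
    if should_use_fuse_rope(hidden_states, position_ids, training):
        import xe_addons  # native kernels, only present in an XPU build
        xe_addons.rotary_half_inplaced(inv_freq, position_ids, query, key)
    else:
        # Generic fallback: materialize cos/sin and apply RoPE in plain PyTorch.
        cos, sin = rotary_emb(value, seq_len=kv_seq_len)
        query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "llama")
    return query, key

The fused path mutates query and key in place and never builds a Python-side cos/sin cache, which is what lets the per-call apply_rotary_pos_emb_no_cache_xpu wrapper go away.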
View file

@@ -35,7 +35,6 @@ import torch
 from typing import Optional, Tuple
 import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import update_past_key_value
@@ -77,10 +76,9 @@ def decilm_attention_forward_4_35_2(
         kv_seq_len += past_key_value[0].shape[-2]
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                      key_states,
-                                                                      position_ids,
-                                                                      "llama")
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,

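The DeciLM hunk shows the two paths side by side: the fused kernel driven directly by inv_freq and position_ids, and the cos/sin fallback. Assuming xe_addons.rotary_half_inplaced implements the standard rotate-half RoPE (the same math as the fallback), a pure-PyTorch reference of that computation looks roughly like the following; rope_reference and rotate_half are illustrative names, not ipex-llm APIs.

# Illustrative reference only, under the assumption that the fused kernel follows
# the usual rotate-half convention with the rope theta already baked into inv_freq.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in two, swap the halves and negate the second one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_reference(q, k, inv_freq, position_ids):
    # q, k: [batch, num_heads, seq_len, head_dim]; inv_freq: [head_dim // 2];
    # position_ids: [batch, seq_len].
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]
    emb = torch.cat((freqs, freqs), dim=-1)   # [batch, seq_len, head_dim]
    cos = emb.cos().unsqueeze(1)              # add a heads dim so it broadcasts
    sin = emb.sin().unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


# Tiny CPU usage example with the 4D layout the fused kernel expects.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
position_ids = torch.arange(16).unsqueeze(0)
q_rot, k_rot = rope_reference(q, k, inv_freq, position_ids)

The fused kernel performs this update in place on XPU; the sketch returns new tensors purely for readability.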
View file

@@ -33,10 +33,10 @@
 import torch
 from typing import Optional, Tuple
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 import os
@@ -44,14 +44,14 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH",
 def gptneox_attention_forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        attention_mask: torch.FloatTensor,
-        position_ids: torch.LongTensor,
-        head_mask: Optional[torch.FloatTensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
+    self,
+    hidden_states: torch.FloatTensor,
+    attention_mask: torch.FloatTensor,
+    position_ids: torch.LongTensor,
+    head_mask: Optional[torch.FloatTensor] = None,
+    layer_past: Optional[Tuple[torch.Tensor]] = None,
+    use_cache: Optional[bool] = False,
+    output_attentions: Optional[bool] = False,
 ):
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
@@ -89,11 +89,12 @@ def gptneox_attention_forward(
-    use_fuse_rope = query.device.type == "xpu"
-    use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
-    if use_fuse_rope:
-        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
-                                                        key_rot,
-                                                        position_ids,
-                                                        "gpt_neox")
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                       query_rot, key_rot)
+        query = query_rot
+        key = key_rot
     else:
         cos, sin = self.rotary_emb(value, seq_len=seq_len)
         query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")

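In the GPT-NeoX forward, the hand-rolled use_fuse_rope check is folded into the shared should_use_fuse_rope helper. Judging only from the inline logic deleted above, the gate amounts to roughly the following sketch; the real helper lives in ipex_llm.transformers.models.utils and may check more conditions.

# Sketch of the gating logic, reconstructed from the inline check removed above;
# not the actual ipex-llm implementation.
import torch


def should_use_fuse_rope_sketch(hidden_states: torch.Tensor,
                                position_ids: torch.Tensor,
                                training: bool) -> bool:
    # The fused kernel only exists on XPU, and its in-place update is not
    # autograd-friendly, so skip it when training with gradients enabled.
    # position_ids is accepted to mirror the real helper's call signature,
    # even though this simplified sketch does not inspect it.
    on_xpu = hidden_states.device.type == "xpu"
    no_grad_needed = not (training and hidden_states.requires_grad)
    return on_xpu and no_grad_needed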
View file

@@ -207,23 +207,6 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
     torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
-
-
-def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
-    if q.device.type != "xpu":
-        invalidInputError(False,
-                          f"only xpu is supported in this function")
-    import xe_addons
-    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
-    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
-    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral"]:
-        xe_addons.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
-                                                       q_embed, k_embed, rope_theta)
-        return q_embed, k_embed
-    else:
-        invalidInputError(False,
-                          f"{model_family} is not supported.")


 def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
     if q.device.type != "xpu":
         invalidInputError(False,