refactor to remove old rope usage (#12224)
parent 324bcb057e
commit 9ea694484d

5 changed files with 34 additions and 56 deletions

@@ -295,17 +295,15 @@ class Attention(nn.Module):
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
 
-        use_fuse_rope = query_layer.device.type == "xpu"
-        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
-        if use_fuse_rope:
-            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+        from ipex_llm.transformers.models.utils import should_use_fuse_rope
+        if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
+                isinstance(self.maybe_rotary, RotaryEmbedding):
+            # resize qk to 4D to match rotary_half_inplaced's requirements.
             query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
-            from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
-                                                                       key_layer,
-                                                                       position_ids,
-                                                                       "gpt_neox")
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                           query_layer, key_layer)
             query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
         else:

@@ -295,17 +295,15 @@ class Attention(nn.Module):
         # query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
         past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
 
-        use_fuse_rope = query_layer.device.type == "xpu"
-        use_fuse_rope = use_fuse_rope and not (self.training and query_layer.requires_grad)
-        if use_fuse_rope:
-            # resize qk to 4D to match apply_rotary_pos_emb_no_cache_xpu's requirements.
+        from ipex_llm.transformers.models.utils import should_use_fuse_rope
+        if should_use_fuse_rope(hidden_states, position_ids, self.training) and \
+                isinstance(self.maybe_rotary, RotaryEmbedding):
+            # resize qk to 4D to match rotary_half_inplaced's requirements.
             query_layer = query_layer.reshape(batch_size, self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size, self.num_kv, q_length, self.head_dim)
-            from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-            query_layer, key_layer = apply_rotary_pos_emb_no_cache_xpu(query_layer,
-                                                                       key_layer,
-                                                                       position_ids,
-                                                                       "gpt_neox")
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.maybe_rotary.inv_freq, position_ids,
+                                           query_layer, key_layer)
             query_layer = query_layer.reshape(batch_size * self.num_heads, q_length, self.head_dim)
             key_layer = key_layer.reshape(batch_size * self.num_kv, q_length, self.head_dim)
         else:

@@ -35,7 +35,6 @@ import torch
 from typing import Optional, Tuple
 import torch.nn.functional as F
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import update_past_key_value
@@ -77,10 +76,9 @@ def decilm_attention_forward_4_35_2(
         kv_seq_len += past_key_value[0].shape[-2]
 
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama")
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,

@@ -33,10 +33,10 @@
 
 import torch
 from typing import Optional, Tuple
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 
 import os
 
@@ -44,14 +44,14 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH",
 
 
 def gptneox_attention_forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        attention_mask: torch.FloatTensor,
-        position_ids: torch.LongTensor,
-        head_mask: Optional[torch.FloatTensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
+    self,
+    hidden_states: torch.FloatTensor,
+    attention_mask: torch.FloatTensor,
+    position_ids: torch.LongTensor,
+    head_mask: Optional[torch.FloatTensor] = None,
+    layer_past: Optional[Tuple[torch.Tensor]] = None,
+    use_cache: Optional[bool] = False,
+    output_attentions: Optional[bool] = False,
 ):
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
@@ -89,11 +89,12 @@ def gptneox_attention_forward(
     use_fuse_rope = query.device.type == "xpu"
     use_fuse_rope = use_fuse_rope and not (self.training and query.requires_grad)
 
-    if use_fuse_rope:
-        query, key = apply_rotary_pos_emb_no_cache_xpu(query_rot,
-                                                       key_rot,
-                                                       position_ids,
-                                                       "gpt_neox")
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_rot, key_rot)
+        query = query_rot
+        key = key_rot
     else:
         cos, sin = self.rotary_emb(value, seq_len=seq_len)
         query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids, "gpt_neox")

@@ -207,23 +207,6 @@ def apply_ipex_rotate_every_two(q, k, cos, sin):
         torch.ops.torch_ipex.apply_rotary_embedding(k, sin, cos, k)
 
 
-def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_theta=10000.0):
-    if q.device.type != "xpu":
-        invalidInputError(False,
-                          f"only xpu is supported in this function")
-    import xe_addons
-    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
-    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
-    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral"]:
-        xe_addons.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
-                                                      q_embed, k_embed, rope_theta)
-        return q_embed, k_embed
-    else:
-        invalidInputError(False,
-                          f"{model_family} is not supported.")
-
-
 def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids=None):
     if q.device.type != "xpu":
         invalidInputError(False,
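
For orientation, the pattern this refactor converges on across the changed files replaces the per-model helper apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family) with a should_use_fuse_rope(...) check followed by a direct xe_addons.rotary_half_inplaced(inv_freq, position_ids, q, k) call. The sketch below is illustrative only and not a copy of any changed file: it assumes an ipex-llm install with the XPU-only xe_addons extension, a llama-style attention module whose rotary_emb exposes inv_freq, and 4D [batch, heads, seq, head_dim] query/key tensors.

# Minimal sketch of the new fused-RoPE call pattern (assumptions noted above).
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb


def rotate_query_key(query_states, key_states, value_states, hidden_states,
                     position_ids, rotary_emb, training, kv_seq_len):
    if should_use_fuse_rope(hidden_states, position_ids, training):
        # Fused XPU path: rotate query/key in place from the inv_freq table.
        import xe_addons
        xe_addons.rotary_half_inplaced(rotary_emb.inv_freq, position_ids,
                                       query_states, key_states)
    else:
        # Non-fused fallback: the existing cos/sin rotation kept in the else branches above.
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                        cos, sin, position_ids, "llama")
    return query_states, key_states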