using bigdl-llm fused rope for llama (#9066)
* optimize llama xpu rope
* fix bug
* fix style
* refine append cache
* remove check
* do not cache cos sin
* remove unnecessary changes
* clean up
* fix style
* check for training
parent 50044640c0
commit fcb1c618a0

2 changed files with 27 additions and 4 deletions
@@ -39,6 +39,7 @@ import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -58,7 +59,7 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256


 def llama_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu":
+    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                          [self.weight.size(0)], self.weight)
     else:
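For reference, the `else` branch kept for CPU tensors (and now also for training tensors that require grad) follows the standard LLaMA RMSNorm computation that the fused `torch_ipex.rms_norm` call replaces. A minimal sketch of that reference path, assuming Hugging Face-style `weight` and `variance_epsilon` attributes:

```python
import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       variance_epsilon: float = 1e-6) -> torch.Tensor:
    # Standard LLaMA RMSNorm: normalize by the root-mean-square of the last
    # dimension in fp32, then rescale by the learned weight.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return weight * hidden_states.to(input_dtype)
```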
@@ -116,9 +117,16 @@ def llama_attention_forward_4_31(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids, "llama")
+
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "llama")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "llama")

     if past_key_value is not None:
         # reuse k, v, self_attention
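The fused XPU path folds the cos/sin cache lookup and the rotation into a single kernel call, so no cos/sin cache is built on that branch. For reference, the `"llama"` branch it replaces applies the standard rotary position embedding from transformers 4.31. A minimal sketch of the equivalent math, with the `[1, 1, seq_len, head_dim]` cos/sin cache layout and `[batch, num_heads, seq_len, head_dim]` q/k shapes stated as assumptions:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotate the last dimension by swapping its halves and negating the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_reference(q, k, cos, sin, position_ids):
    # Gather the cos/sin rows for each position, then rotate q and k:
    # q' = q * cos + rotate_half(q) * sin, and the same for k.
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

The hunk below, in the shared models utils module that the llama code imports from, adds the fused helper itself.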
@@ -97,3 +97,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
     else:
         invalidInputError(False,
                           f"{model_family} is not supported.")
+
+
+def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
+    if q.device.type != "xpu":
+        invalidInputError(False,
+                          f"only xpu is supported in this function")
+    import linear_q4_0
+    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox"]:
+        linear_q4_0.apply_rotary_embedding_half_qk(q, k, position_ids, q_embed, k_embed)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")