add fused rotary pos emb for qwen (#9956)

* add fused rotary pos emb for qwen
* update

parent 7b1d9ad7c0
commit dacf680294

2 changed files with 37 additions and 10 deletions
bigdl/llm/transformers/models/qwen.py

@@ -41,6 +41,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv
     append_fp8_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -89,25 +90,36 @@ def qwen_attention_forward(
     # query, key, value's shape: [bs, seq_len, num_heads, head_dim]

     if rotary_pos_emb_list is not None:
+        use_fuse_rope = query.device.type == "xpu" and not (self.training and query.requires_grad)
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
             rotary_pos_emb = rotary_pos_emb_list[0]
             rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
+            if use_fuse_rope:
+                cos, sin = rotary_pos_emb
+                query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+            else:
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query = apply_rotary_pos_emb(query, q_pos_emb)
+                key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
             for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
                 rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                # Slice the pos emb for current inference
-                query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-                key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+                if use_fuse_rope:
+                    cos, sin = rotary_pos_emb
+                    query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
+                    query_list += [query]
+                    key_list += [key]
+                else:
+                    rotary_pos_emb = (rotary_pos_emb,) * 2
+                    q_pos_emb, k_pos_emb = rotary_pos_emb
+                    # Slice the pos emb for current inference
+                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
+                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
             query = torch.cat(query_list, dim=0)
             key = torch.cat(key_list, dim=0)
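The change above gates the rotary embedding on use_fuse_rope: during XPU inference (not training), query and key go through a single fused call to apply_rotary_pos_emb_cache_freq_xpu instead of two separate apply_rotary_pos_emb calls. For orientation, a minimal sketch of the rotate_half-style rotary math the fused call stands in for, assuming the usual cos/sin layout broadcastable to [bs, seq_len, num_heads, head_dim] (the exact behavior lives in the linear_q4_0 kernel and may differ):

import torch

def rotate_half(x):
    # rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(t, cos, sin):
    # t: [bs, seq_len, num_heads, head_dim]; cos/sin broadcastable to t
    # illustrative stand-in for what the fused XPU kernel computes per tensor
    return (t * cos) + (rotate_half(t) * sin)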
bigdl/llm/transformers/models/utils.py
@@ -177,6 +177,21 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
                           f"{model_family} is not supported.")


+def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family):
+    if q.device.type != "xpu":
+        invalidInputError(False,
+                          f"only xpu is supported in this function")
+    import linear_q4_0
+    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
+    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
+    if model_family in ["qwen"]:
+        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
+        return q_embed, k_embed
+    else:
+        invalidInputError(False,
+                          f"{model_family} is not supported.")
+
+
 def is_enough_kv_cache_room_4_36(past_key_value, idx, seq_len=1):
     # to determinate if is enough kv cache room in transformers==4.36
     return past_key_value is not None and len(past_key_value.key_cache) > idx and \
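The new helper checks that the inputs are on XPU, allocates empty q_embed and k_embed outputs, and dispatches to the linear_q4_0 native kernel, which is expected to write the rotated query and key into them; only the "qwen" model family is accepted, anything else raises through invalidInputError. A minimal sketch of how a caller could mirror the dispatch pattern used in qwen_attention_forward above, with an eager fallback off XPU (the wrapper name and fallback are illustrative, not part of the library):

from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb_cache_freq_xpu

def apply_rotary_pos_emb_with_fallback(query, key, sin, cos):
    # hypothetical wrapper: prefer the fused XPU kernel added by this commit,
    # otherwise fall back to eager rotate_half-style math
    if query.device.type == "xpu":
        return apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
    q_embed = (query * cos) + (rotate_half(query) * sin)
    k_embed = (key * cos) + (rotate_half(key) * sin)
    return q_embed, k_embed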