parent a54cd767b1
commit 20e9742fa0

3 changed files with 62 additions and 57 deletions
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.llama import get_ipex_version
 
 
@@ -365,11 +366,10 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+        if attention_mask is None and use_flash_attention(query_layer):
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
-                                                                             attention_mask,
                                                                              is_causal=True)
         elif attention_mask is None:
             scaling_factor = 1 / math.sqrt(query_layer.size(-1))
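Note: the change above replaces the shape-based prefill check (query and key sharing the same sequence length) with the shared use_flash_attention helper, and no longer passes the attention mask (which is None in this branch) to the fused kernel. A minimal sketch of the resulting dispatch, assuming a [batch, heads, seq, head_dim] layout after the permute; the explicit fallback branch here is illustrative, not the exact chatglm2 code:

import math
import torch
import torch.nn.functional as F

def core_attn_sketch(query_layer, key_layer, value_layer, attention_mask, can_use_flash):
    # query/key/value assumed to be [batch, heads, seq, head_dim] after the permute above
    if attention_mask is None and can_use_flash:
        # prefill with batch_size == 1: let the fused SDP kernel apply causal masking itself
        return F.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                              is_causal=True)
    # otherwise fall back to an explicit matmul + softmax ("native" SDP)
    scaling_factor = 1 / math.sqrt(query_layer.size(-1))
    attn = torch.matmul(query_layer * scaling_factor, key_layer.transpose(-2, -1))
    if attention_mask is not None:
        attn = attn + attention_mask
    attn = torch.softmax(attn, dim=-1)
    return torch.matmul(attn, value_layer)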
@@ -42,6 +42,7 @@ from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -508,61 +509,6 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values
 
 
-def use_flash_attention(query):
-    bsz, q_len, _ = query.size()
-    # check whether ipex flash attention can be used
-    if bsz > 1:
-        # only use flash attention for batch_size = 1 now
-        # as flash attention doesn't support attn_mask in ipex 2.1,
-        # so it will cause output error for padded batch input
-        return False
-    if q_len == 1:
-        # now only use flash attention for first token
-        # as it seems have no performance benifit for rest token now
-        return False
-    if query.device.type != "xpu":
-        # ipex flash attention only support for xpu
-        return False
-    ipex_version = get_ipex_version()
-    if ipex_version <= "2.0.110+xpu":
-        # ipex flash attention is supported from ipex 2.1
-        return False
-    if not torch.xpu.has_xetla():
-        # ipex flash attention is only supported for xetla
-        # may update this later
-        return False
-    if query.dtype not in [torch.float32, torch.float16]:
-        # only use flash attention for fp32/fp16 input
-        return False
-    return True
-
-
-def use_esimd_sdp(q_len, head_dim, query_states):
-    if head_dim != 128:
-        # esimd_sdp only support head_dim = 128 now
-        return False
-    elif q_len != 1:
-        # esimd_sdp only support rest token now
-        return False
-    elif query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-    else:
-        device_name = torch.xpu.get_device_name(query_states.device.index)
-        if device_name.startswith("Intel(R) Arc(TM) A") or \
-                device_name.startswith("Intel(R) Data Center GPU Flex"):
-            import linear_fp16_esimd
-            if hasattr(linear_fp16_esimd, "sdp_forward"):
-                return True
-            else:
-                return False
-        else:
-            return False
-
-
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
     attn_weights = torch.matmul(query,
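The two helpers deleted here are not dropped: they move to bigdl.llm.transformers.models.utils (last file in this diff) and are re-imported into this file via the import hunk above. A hypothetical caller-side sketch of how an attention forward can consult them; the function name, variable names, and string labels are assumptions for illustration, only use_flash_attention and use_esimd_sdp come from this commit:

from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp

def pick_attention_path(query_states, q_len, head_dim, attention_mask):
    # query_states: [bsz, q_len, hidden] (3-D) or [bsz, heads, q_len, head_dim] (4-D)
    if attention_mask is None and use_flash_attention(query_states):
        return "ipex_flash_attention"   # prefill, bsz == 1, fp16/fp32 on XPU
    if use_esimd_sdp(q_len, head_dim, query_states):
        return "esimd_sdp"              # decode step (q_len == 1), fp16, head_dim == 128
    return "native_sdp"                 # explicit matmul + softmax fallback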
@@ -16,6 +16,7 @@
 
 import torch
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import get_ipex_version
 
 
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -119,3 +120,61 @@ def is_enough_kv_cache_room_4_31(past_key_value):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
     return past_key_value is not None and \
         past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
+
+
+def use_flash_attention(query):
+    if query.dim() == 3:
+        bsz, q_len, _ = query.size()
+    elif query.dim() == 4:
+        bsz, _, q_len, _ = query.size()
+    # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 now
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause output error for padded batch input
+        return False
+    if q_len == 1:
+        # now only use flash attention for first token
+        # as it seems have no performance benifit for rest token now
+        return False
+    if query.device.type != "xpu":
+        # ipex flash attention only support for xpu
+        return False
+    ipex_version = get_ipex_version()
+    if ipex_version <= "2.0.110+xpu":
+        # ipex flash attention is supported from ipex 2.1
+        return False
+    if not torch.xpu.has_xetla():
+        # ipex flash attention is only supported for xetla
+        # may update this later
+        return False
+    if query.dtype not in [torch.float32, torch.float16]:
+        # only use flash attention for fp32/fp16 input
+        return False
+    return True
+
+
+def use_esimd_sdp(q_len, head_dim, query_states):
+    if head_dim != 128:
+        # esimd_sdp only support head_dim = 128 now
+        return False
+    elif q_len != 1:
+        # esimd_sdp only support rest token now
+        return False
+    elif query_states.device.type != "xpu":
+        # esimd_sdp only support GPU now
+        return False
+    elif query_states.dtype != torch.float16:
+        # esimd_sdp only has optimization for FP16 now
+        return False
+    else:
+        device_name = torch.xpu.get_device_name(query_states.device.index)
+        if device_name.startswith("Intel(R) Arc(TM) A") or \
+                device_name.startswith("Intel(R) Data Center GPU Flex"):
+            import linear_fp16_esimd
+            if hasattr(linear_fp16_esimd, "sdp_forward"):
+                return True
+            else:
+                return False
+        else:
+            return False
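Compared with the llama-local copy removed above (which always unpacked a 3-D [bsz, q_len, hidden] query), the relocated use_flash_attention also accepts a 4-D query, so chatglm2 can pass its permuted [bsz, heads, q_len, head_dim] query_layer directly. A tiny self-contained check of that shape branch; the concrete sizes are illustrative assumptions:

import torch

# Both layouts map to the same (bsz, q_len) pair in the dim check above.
q3 = torch.randn(1, 32, 4096)      # 3-D, llama-style: [bsz, q_len, hidden]
q4 = torch.randn(1, 32, 32, 128)   # 4-D, chatglm2-style after permute: [bsz, heads, q_len, head_dim]

for q in (q3, q4):
    if q.dim() == 3:
        bsz, q_len, _ = q.size()
    elif q.dim() == 4:
        bsz, _, q_len, _ = q.size()
    print(q.dim(), bsz, q_len)     # prints "3 1 32" then "4 1 32"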