parent
a54cd767b1
commit
20e9742fa0
3 changed files with 62 additions and 57 deletions
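Summary of the change (as read from the hunks below): the use_flash_attention and use_esimd_sdp helpers are moved out of the LLaMA module into bigdl.llm.transformers.models.utils (use_flash_attention now accepts both 3-D and 4-D query tensors), the LLaMA code re-imports them from the new location, and core_attn_forward_8eb45c now gates its scaled_dot_product_attention path with use_flash_attention while dropping the redundant attention_mask argument.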
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.llama import get_ipex_version
 
 
@@ -365,11 +366,10 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+        if attention_mask is None and use_flash_attention(query_layer):
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
-                                                                             attention_mask,
                                                                              is_causal=True)
         elif attention_mask is None:
             scaling_factor = 1 / math.sqrt(query_layer.size(-1))
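In the hunk above, the flash path is now gated by use_flash_attention (added to models/utils.py later in this diff) instead of a bare sequence-length comparison, and the attention_mask argument is dropped from the SDPA call. A minimal sketch of why that argument is unnecessary on this path; the tensor shapes here are assumptions for illustration only (the [batch, heads, seq, head_dim] layout produced by the permute above):

import torch
import torch.nn.functional as F

# Illustrative tensors only; 1 batch x 32 heads x 128 tokens x 64 head_dim is an assumption.
q = torch.randn(1, 32, 128, 64)
k = torch.randn(1, 32, 128, 64)
v = torch.randn(1, 32, 128, 64)

# is_causal=True tells SDPA to build the causal mask itself. PyTorch 2.x rejects a
# non-None attn_mask combined with is_causal=True, and this branch only runs when
# attention_mask is None, so the explicit mask argument can simply be removed.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Passing a real mask together with is_causal=True would raise an error instead:
# F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=True)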
@@ -42,6 +42,7 @@ from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -508,61 +509,6 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values
 
 
-def use_flash_attention(query):
-    bsz, q_len, _ = query.size()
-    # check whether ipex flash attention can be used
-    if bsz > 1:
-        # only use flash attention for batch_size = 1 now
-        # as flash attention doesn't support attn_mask in ipex 2.1,
-        # so it will cause output error for padded batch input
-        return False
-    if q_len == 1:
-        # now only use flash attention for first token
-        # as it seems have no performance benifit for rest token now
-        return False
-    if query.device.type != "xpu":
-        # ipex flash attention only support for xpu
-        return False
-    ipex_version = get_ipex_version()
-    if ipex_version <= "2.0.110+xpu":
-        # ipex flash attention is supported from ipex 2.1
-        return False
-    if not torch.xpu.has_xetla():
-        # ipex flash attention is only supported for xetla
-        # may update this later
-        return False
-    if query.dtype not in [torch.float32, torch.float16]:
-        # only use flash attention for fp32/fp16 input
-        return False
-    return True
-
-
-def use_esimd_sdp(q_len, head_dim, query_states):
-    if head_dim != 128:
-        # esimd_sdp only support head_dim = 128 now
-        return False
-    elif q_len != 1:
-        # esimd_sdp only support rest token now
-        return False
-    elif query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-    else:
-        device_name = torch.xpu.get_device_name(query_states.device.index)
-        if device_name.startswith("Intel(R) Arc(TM) A") or \
-                device_name.startswith("Intel(R) Data Center GPU Flex"):
-            import linear_fp16_esimd
-            if hasattr(linear_fp16_esimd, "sdp_forward"):
-                return True
-            else:
-                return False
-        else:
-            return False
-
-
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
     attn_weights = torch.matmul(query,
@@ -16,6 +16,7 @@
 
 import torch
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import get_ipex_version
 
 
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -119,3 +120,61 @@ def is_enough_kv_cache_room_4_31(past_key_value):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
     return past_key_value is not None and \
         past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
+
+
+def use_flash_attention(query):
+    if query.dim() == 3:
+        bsz, q_len, _ = query.size()
+    elif query.dim() == 4:
+        bsz, _, q_len, _ = query.size()
+    # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 now,
+        # as flash attention doesn't support attn_mask in ipex 2.1
+        # and would produce wrong output for padded batch input
+        return False
+    if q_len == 1:
+        # only use flash attention for the first token for now,
+        # as it seems to have no performance benefit for rest tokens
+        return False
+    if query.device.type != "xpu":
+        # ipex flash attention only supports xpu
+        return False
+    ipex_version = get_ipex_version()
+    if ipex_version <= "2.0.110+xpu":
+        # ipex flash attention is supported from ipex 2.1
+        return False
+    if not torch.xpu.has_xetla():
+        # ipex flash attention is only supported for xetla
+        # may update this later
+        return False
+    if query.dtype not in [torch.float32, torch.float16]:
+        # only use flash attention for fp32/fp16 input
+        return False
+    return True
+
+
+def use_esimd_sdp(q_len, head_dim, query_states):
+    if head_dim != 128:
+        # esimd_sdp only supports head_dim = 128 now
+        return False
+    elif q_len != 1:
+        # esimd_sdp only supports rest tokens now
+        return False
+    elif query_states.device.type != "xpu":
+        # esimd_sdp only supports GPU now
+        return False
+    elif query_states.dtype != torch.float16:
+        # esimd_sdp only has an optimization for FP16 now
+        return False
+    else:
+        device_name = torch.xpu.get_device_name(query_states.device.index)
+        if device_name.startswith("Intel(R) Arc(TM) A") or \
+                device_name.startswith("Intel(R) Data Center GPU Flex"):
+            import linear_fp16_esimd
+            if hasattr(linear_fp16_esimd, "sdp_forward"):
+                return True
+            else:
+                return False
+        else:
+            return False
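Compared with the version removed from the LLaMA module, the relocated use_flash_attention also accepts a 4-D [batch, heads, seq, head_dim] query in addition to the 3-D [batch, seq, hidden] layout, which is what lets the core_attn_forward_8eb45c path above pass its permuted query directly. As a rough illustration of how a caller might combine the two gates, here is a hypothetical dispatcher; sdp_dispatch and its fallback path are sketches under assumed shapes and are not part of this commit, only the two imported helpers are:

import torch
import torch.nn.functional as F

from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp


def sdp_dispatch(query, key, value, attention_mask):
    # Hypothetical helper: query/key/value assumed to be [batch, heads, seq, head_dim].
    _, _, q_len, head_dim = query.size()
    if attention_mask is None and use_flash_attention(query):
        # first-token (prefill) path on XPU: let SDPA/IPEX pick a flash kernel
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
    if use_esimd_sdp(q_len, head_dim, query):
        # rest-token (decode) path on Arc/Flex in FP16: the commit only adds this gate;
        # the actual linear_fp16_esimd.sdp_forward call is made in the model code and is
        # not shown in this diff, so this sketch simply falls through.
        pass
    # generic fallback: plain scaled dot-product attention with an optional additive mask
    scale = 1.0 / (head_dim ** 0.5)
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)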