LLM: fix chatglm3 issue (#9820)

* fix chatglm3 issue

* small update
Ruonan Wang 2024-01-03 16:15:55 +08:00 committed by GitHub
parent a54cd767b1
commit 20e9742fa0
3 changed files with 62 additions and 57 deletions


@@ -23,6 +23,7 @@ from typing import Optional, Tuple, List
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import use_flash_attention
from bigdl.llm.transformers.models.llama import get_ipex_version
@@ -365,11 +366,10 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
    pytorch_major_version = int(torch.__version__.split('.')[0])
    if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
        query_layer = query_layer.permute(1, 2, 0, 3)
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
        if attention_mask is None and use_flash_attention(query_layer):
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                             key_layer,
                                                                             value_layer,
                                                                             attention_mask,
                                                                             is_causal=True)
        elif attention_mask is None:
            scaling_factor = 1 / math.sqrt(query_layer.size(-1))
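
For reference, the branch above boils down to "use PyTorch SDPA with is_causal=True when the flash-attention check passes, otherwise fall back to a manual scaled-matmul path". The following is a minimal, self-contained sketch of that pattern, not the BigDL implementation: sdpa_or_manual, flash_ok and the explicit causal-mask fallback are hypothetical stand-ins for the branch in this hunk, which also handles a non-None attention_mask.

import math
import torch
import torch.nn.functional as F

def sdpa_or_manual(query_layer, key_layer, value_layer, attention_mask, flash_ok):
    # hypothetical stand-in for the branch above; tensors are (batch, heads, seq, head_dim)
    if attention_mask is None and flash_ok:
        # flash/SDPA path: no explicit mask is passed, is_causal=True builds it internally
        return F.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                              is_causal=True)
    # manual path, mirroring the scaling_factor computation shown above; handling of a
    # non-None attention_mask is omitted here for brevity
    scaling_factor = 1 / math.sqrt(query_layer.size(-1))
    scores = torch.matmul(query_layer * scaling_factor, key_layer.transpose(-2, -1))
    causal = torch.ones(query_layer.size(-2), key_layer.size(-2),
                        dtype=torch.bool, device=query_layer.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), value_layer)

if __name__ == "__main__":
    q = k = v = torch.randn(1, 2, 8, 64)
    flash_out = sdpa_or_manual(q, k, v, None, flash_ok=True)
    manual_out = sdpa_or_manual(q, k, v, None, flash_ok=False)
    print(torch.allclose(flash_out, manual_out, atol=1e-4))

Both paths compute the same causal self-attention; the point of the new use_flash_attention gate is to take the SDPA path only when the kernel is actually usable on the current device, dtype and IPEX version.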


@@ -42,6 +42,7 @@ from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from transformers.modeling_outputs import BaseModelOutputWithPast
from bigdl.llm.transformers.low_bit_linear import SYM_INT4
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -508,61 +509,6 @@ def llama_attention_selective_batching_forward_4_31(
    return attn_output.to(original_dtype), attn_weights, updated_past_key_values


def use_flash_attention(query):
    bsz, q_len, _ = query.size()
    # check whether ipex flash attention can be used
    if bsz > 1:
        # only use flash attention for batch_size = 1 for now,
        # as flash attention doesn't support attn_mask in ipex 2.1,
        # so it would cause output errors for padded batch input
        return False
    if q_len == 1:
        # for now, only use flash attention for the first token,
        # as it seems to have no performance benefit for rest tokens
        return False
    if query.device.type != "xpu":
        # ipex flash attention is only supported on xpu
        return False
    ipex_version = get_ipex_version()
    if ipex_version <= "2.0.110+xpu":
        # ipex flash attention is supported from ipex 2.1
        return False
    if not torch.xpu.has_xetla():
        # ipex flash attention is only supported with xetla,
        # may update this later
        return False
    if query.dtype not in [torch.float32, torch.float16]:
        # only use flash attention for fp32/fp16 input
        return False
    return True


def use_esimd_sdp(q_len, head_dim, query_states):
    if head_dim != 128:
        # esimd_sdp only supports head_dim = 128 for now
        return False
    elif q_len != 1:
        # esimd_sdp only supports rest tokens for now
        return False
    elif query_states.device.type != "xpu":
        # esimd_sdp only supports GPU for now
        return False
    elif query_states.dtype != torch.float16:
        # esimd_sdp is only optimized for FP16 for now
        return False
    else:
        device_name = torch.xpu.get_device_name(query_states.device.index)
        if device_name.startswith("Intel(R) Arc(TM) A") or \
           device_name.startswith("Intel(R) Data Center GPU Flex"):
            import linear_fp16_esimd
            if hasattr(linear_fp16_esimd, "sdp_forward"):
                return True
            else:
                return False
        else:
            return False


def native_sdp(query, key, value, attention_mask,
               bsz, q_len, kv_seq_len, head_dim, num_heads):
    attn_weights = torch.matmul(query,


@@ -16,6 +16,7 @@
import torch
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.utils import get_ipex_version
def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -119,3 +120,61 @@ def is_enough_kv_cache_room_4_31(past_key_value):
    # to determine whether there is enough kv cache room for transformers between 4.31 and 4.35
    return past_key_value is not None and \
        past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)


def use_flash_attention(query):
    if query.dim() == 3:
        bsz, q_len, _ = query.size()
    elif query.dim() == 4:
        bsz, _, q_len, _ = query.size()
    # check whether ipex flash attention can be used
    if bsz > 1:
        # only use flash attention for batch_size = 1 for now,
        # as flash attention doesn't support attn_mask in ipex 2.1,
        # so it would cause output errors for padded batch input
        return False
    if q_len == 1:
        # for now, only use flash attention for the first token,
        # as it seems to have no performance benefit for rest tokens
        return False
    if query.device.type != "xpu":
        # ipex flash attention is only supported on xpu
        return False
    ipex_version = get_ipex_version()
    if ipex_version <= "2.0.110+xpu":
        # ipex flash attention is supported from ipex 2.1
        return False
    if not torch.xpu.has_xetla():
        # ipex flash attention is only supported with xetla,
        # may update this later
        return False
    if query.dtype not in [torch.float32, torch.float16]:
        # only use flash attention for fp32/fp16 input
        return False
    return True


def use_esimd_sdp(q_len, head_dim, query_states):
    if head_dim != 128:
        # esimd_sdp only supports head_dim = 128 for now
        return False
    elif q_len != 1:
        # esimd_sdp only supports rest tokens for now
        return False
    elif query_states.device.type != "xpu":
        # esimd_sdp only supports GPU for now
        return False
    elif query_states.dtype != torch.float16:
        # esimd_sdp is only optimized for FP16 for now
        return False
    else:
        device_name = torch.xpu.get_device_name(query_states.device.index)
        if device_name.startswith("Intel(R) Arc(TM) A") or \
           device_name.startswith("Intel(R) Data Center GPU Flex"):
            import linear_fp16_esimd
            if hasattr(linear_fp16_esimd, "sdp_forward"):
                return True
            else:
                return False
        else:
            return False
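
For context, here is a minimal usage sketch of the two helpers this hunk adds to models/utils.py. It is an illustrative example rather than code from the commit: the shapes, dtype and device choices are arbitrary assumptions, and on a CPU-only machine both checks simply return False, so callers keep their non-flash / non-ESIMD paths.

import torch
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp

# assumed layout: (batch, num_heads, seq_len, head_dim); fp16 is only meaningful on xpu
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
dtype = torch.float16 if device == "xpu" else torch.float32
query_states = torch.randn(1, 32, 256, 128, dtype=dtype, device=device)

# first-token (prefill) step: flash attention is used only if every check passes
print(use_flash_attention(query_states))

# rest-token (decode) step with q_len == 1: the esimd sdp kernel may be used instead
decode_query = query_states[:, :, :1, :]
print(use_esimd_sdp(q_len=1, head_dim=decode_query.size(-1), query_states=decode_query))

Per this diff, chatglm2's core_attn_forward_8eb45c gates its SDPA path on use_flash_attention, and llama.py now imports both helpers from here instead of defining them locally.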