LLM: add esimd sdp support for chatglm3 (#10205)

* add esimd sdp support

* fix style
Author: Ruonan Wang (committed by GitHub)
Date:   2024-02-22 13:37:16 +08:00
Parent: 7cbc2429a6
Commit: 34ee1aa91f


@@ -25,6 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
+from bigdl.llm.transformers.models.utils import use_esimd_sdp


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
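Note on the new import: use_esimd_sdp is the runtime gate that decides whether the fused ESIMD kernel can take over scaled-dot-product attention for the current shapes, and linear_fp16_esimd is only imported lazily once that gate passes (see the second hunk below). As a rough, self-contained sketch of that guarded-dispatch idea, assuming the extension may simply be absent on non-XPU installs and using PyTorch's built-in SDPA as a stand-in fallback (the real fallback is the manual matmul/softmax path kept below):

import torch
import torch.nn.functional as F

def sdp_dispatch(query_layer, key_layer, value_layer):
    # Prefer the fused ESIMD kernel when the extension is importable;
    # linear_fp16_esimd is the compiled extension used on Intel GPUs, so a
    # CPU-only environment falls through to the PyTorch path below.
    try:
        import linear_fp16_esimd
        attn_output = linear_fp16_esimd.sdp_forward(query_layer,
                                                    key_layer,
                                                    value_layer)
        return attn_output.view(query_layer.shape)
    except ImportError:
        # Stand-in fallback for this sketch only; the patch itself gates on
        # use_esimd_sdp(...) rather than on import success.
        return F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                              key_layer,
                                              value_layer)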
@@ -515,23 +516,31 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                         key_layer,
                                                         value_layer,
-                                                        is_causal=True)
+                                                        is_causal=True).to(key_layer.dtype)
     else:
-        head_dim = query_layer.size(-1)
-        attn = torch.matmul(query_layer.to(key_layer.dtype),
-                            key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-        if attention_mask is not None:
-            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                    device=query_layer.device)
-            attention_mask = ~attention_mask
-            if attention_mask.dtype == torch.bool:
-                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-            else:
-                attn_bias += attention_mask
-            attn += attn_bias
-        attn = F.softmax(attn, dim=-1,
-                         dtype=torch.float32).to(value_layer.dtype)
-        context_layer = torch.matmul(attn, value_layer)
+        if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
+                         query_layer.shape[-1], query_layer):
+            import linear_fp16_esimd
+            attn_output = linear_fp16_esimd.sdp_forward(query_layer,
+                                                        key_layer,
+                                                        value_layer)
+            context_layer = attn_output.view(query_layer.shape)
+        else:
+            head_dim = query_layer.size(-1)
+            attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+            if attention_mask is not None:
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                        device=query_layer.device)
+                attention_mask = ~attention_mask
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value_layer.dtype)
+            context_layer = torch.matmul(attn, value_layer)
     context_layer = context_layer.permute(2, 0, 1, 3)
     new_context_layer_shape = context_layer.size()[:-2] + (-1,)
     context_layer = context_layer.reshape(*new_context_layer_shape)
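For reference, the branch this patch keeps as the non-ESIMD fallback (now re-indented under the use_esimd_sdp check) is plain scaled matmul/softmax attention. Pulled out below as a self-contained sketch; the function name manual_core_attn is illustrative only, and the [batch, num_heads, seq_len, head_dim] layout is assumed from the surrounding code:

import math
import torch
import torch.nn.functional as F

def manual_core_attn(query_layer, key_layer, value_layer, attention_mask=None):
    # Mirrors the fallback branch above: scaled QK^T, optional boolean/additive
    # mask folded into an attention bias, fp32 softmax, then weighted sum over values.
    head_dim = query_layer.size(-1)
    attn = torch.matmul(query_layer.to(key_layer.dtype),
                        key_layer.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                device=query_layer.device)
        attention_mask = ~attention_mask
        if attention_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attention_mask
        attn += attn_bias
    attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value_layer.dtype)
    return torch.matmul(attn, value_layer)

# Smoke test with a [batch, num_heads, seq_len, head_dim] layout.
q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)
print(manual_core_attn(q, k, v).shape)  # torch.Size([1, 2, 8, 16])

Keeping this path around matters because the ESIMD kernel is only taken when use_esimd_sdp judges the shapes and dtype suitable; anything it rejects still has a correct, if slower, route.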