From 34ee1aa91f7b4a6f3cc2e6bb86ca8277dea5adcd Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Thu, 22 Feb 2024 13:37:16 +0800
Subject: [PATCH] LLM: add esimd sdp support for chatglm3 (#10205)

* add esimd sdp support

* fix style
---
 .../bigdl/llm/transformers/models/chatglm2.py | 41 +++++++++++--------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 1986b6ce..7ddd80eb 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -25,6 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
+from bigdl.llm.transformers.models.utils import use_esimd_sdp
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
@@ -515,23 +516,31 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                        key_layer,
                                                        value_layer,
-                                                       is_causal=True)
+                                                       is_causal=True).to(key_layer.dtype)
     else:
-        head_dim = query_layer.size(-1)
-        attn = torch.matmul(query_layer.to(key_layer.dtype),
-                            key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-        if attention_mask is not None:
-            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                    device=query_layer.device)
-            attention_mask = ~attention_mask
-            if attention_mask.dtype == torch.bool:
-                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-            else:
-                attn_bias += attention_mask
-            attn += attn_bias
-        attn = F.softmax(attn, dim=-1,
-                         dtype=torch.float32).to(value_layer.dtype)
-        context_layer = torch.matmul(attn, value_layer)
+        if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
+                         query_layer.shape[-1], query_layer):
+            import linear_fp16_esimd
+            attn_output = linear_fp16_esimd.sdp_forward(query_layer,
+                                                        key_layer,
+                                                        value_layer)
+            context_layer = attn_output.view(query_layer.shape)
+        else:
+            head_dim = query_layer.size(-1)
+            attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+            if attention_mask is not None:
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                        device=query_layer.device)
+                attention_mask = ~attention_mask
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value_layer.dtype)
+            context_layer = torch.matmul(attn, value_layer)
     context_layer = context_layer.permute(2, 0, 1, 3)
     new_context_layer_shape = context_layer.size()[:-2] + (-1,)
     context_layer = context_layer.reshape(*new_context_layer_shape)
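
Reviewer note (not part of the patch): below is a minimal standalone sketch of the fallback attention math that the hunk keeps in the non-ESIMD else branch, using plain PyTorch only. The tensor shapes and the causal boolean mask are illustrative assumptions for this sketch, not values taken from the ChatGLM model, and the ESIMD path (use_esimd_sdp / linear_fp16_esimd.sdp_forward) is not exercised here since it needs the Intel ESIMD kernel package.

import math
import torch
import torch.nn.functional as F

# Illustrative shapes (assumed): batch=1, heads=2, seq_len=4, head_dim=8
query_layer = torch.randn(1, 2, 4, 8)
key_layer = torch.randn(1, 2, 4, 8)
value_layer = torch.randn(1, 2, 4, 8)
# Boolean mask in the convention used by the patched code: True = position must NOT be attended
attention_mask = ~torch.ones(4, 4, dtype=torch.bool).tril().view(1, 1, 4, 4)

head_dim = query_layer.size(-1)
# Scaled dot-product scores: Q @ K^T / sqrt(head_dim)
attn = torch.matmul(query_layer.to(key_layer.dtype),
                    key_layer.transpose(2, 3)) / math.sqrt(head_dim)
if attention_mask is not None:
    attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype)
    attention_mask = ~attention_mask  # invert: True now marks positions to keep
    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
    attn += attn_bias
# Softmax in fp32 for stability, then cast back before the value matmul
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value_layer.dtype)
context_layer = torch.matmul(attn, value_layer)
print(context_layer.shape)  # torch.Size([1, 2, 4, 8])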