diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index b4f5604c..08ee571f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -366,14 +366,26 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
-        if attention_mask is None and use_flash_attention(query_layer):
+        L, S = query_layer.shape[2], key_layer.shape[2]
+        if attention_mask is None and (use_flash_attention(query_layer) or
+                                       L == S and query_layer.device.type == "cpu"):
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
                                                                              is_causal=True)
         elif attention_mask is None:
-            scaling_factor = 1 / math.sqrt(query_layer.size(-1))
-            attn = torch.matmul(query_layer * scaling_factor, key_layer.transpose(-2, -1))
+            head_dim = query_layer.size(-1)
+            attn = torch.matmul(query_layer,
+                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+            if L == S:
+                # first token, need attention mask
+                attn_bias = torch.zeros(L, S, dtype=query_layer.dtype,
+                                        device=query_layer.device)
+                temp_mask = torch.ones(L, S, dtype=torch.bool,
+                                       device=query_layer.device).tril(diagonal=0)
+                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+                attn_bias.to(query_layer.dtype)
+                attn += attn_bias
             attn = torch.softmax(attn, -1)
             context_layer = torch.matmul(attn, value_layer)
         else:
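For context, the hunk above does two things: it routes the no-mask prefill case (L == S) on CPU through the fused `scaled_dot_product_attention(..., is_causal=True)` kernel, and it rewrites the remaining fallback as a plain matmul path that only builds an explicit causal bias when L == S (decode steps with a single query token skip the mask). The snippet below is a minimal, standalone sketch of that fallback, written to be checked against the fused kernel; the function name `manual_causal_attention` and the [batch, heads, seq, head_dim] shapes are illustrative assumptions, not code from chatglm2.py.

```python
import math
import torch

def manual_causal_attention(query, key, value):
    """Sketch of the non-fused fallback: scaled dot-product attention
    with an explicit lower-triangular (causal) bias when L == S.

    Assumed shapes (illustrative): [batch, num_heads, seq_len, head_dim].
    """
    L, S = query.shape[2], key.shape[2]
    head_dim = query.size(-1)

    # Attention scores scaled by 1/sqrt(head_dim), as in the patched branch.
    attn = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)

    if L == S:
        # Prefill (first token): position i may only attend to positions <= i,
        # so add -inf above the diagonal before the softmax.
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(causal.logical_not(), float("-inf"))
        attn = attn + attn_bias

    attn = torch.softmax(attn, dim=-1)
    return torch.matmul(attn, value)


if __name__ == "__main__":
    q = torch.randn(1, 2, 8, 16)
    k = torch.randn(1, 2, 8, 16)
    v = torch.randn(1, 2, 8, 16)
    out = manual_causal_attention(q, k, v)
    # Should match PyTorch's fused kernel with is_causal=True (up to float error).
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(out, ref, atol=1e-5)
```

One detail worth noting against the diff itself: `attn_bias.to(query_layer.dtype)` in the added lines returns a new tensor without assigning it, and since `attn_bias` is already created with `dtype=query_layer.dtype` the call is redundant, so the sketch omits it.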