diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 92dfca03..bef388d8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -369,8 +369,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
-        if attention_mask is None and (use_flash_attention(query_layer, key_layer) or
-                                       L == S and query_layer.device.type == "cpu"):
+        if attention_mask is None and L == S:
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
@@ -380,19 +379,16 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
             head_dim = query_layer.size(-1)
             attn = torch.matmul(query_layer, key_layer.transpose(2, 3)) / math.sqrt(head_dim)
             if attention_mask is not None:
-                attention_mask = ~attention_mask
-                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
-                attn += attention_mask
-            elif L == S:
-                # first token, need attention mask
-                attn_bias = torch.zeros(L, S, dtype=query_layer.dtype,
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                         device=query_layer.device)
-                temp_mask = torch.ones(L, S, dtype=torch.bool,
-                                       device=query_layer.device).tril(diagonal=0)
-                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-                attn_bias.to(query_layer.dtype)
+                attention_mask = ~attention_mask
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
                 attn += attn_bias
-            attn = torch.softmax(attn, -1)
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value_layer.dtype)
             context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
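
Below is a minimal, standalone sketch (not part of the patch) of the masking path introduced above, so the new behaviour can be exercised outside the model. The function name masked_attention and the toy tensor shapes are invented for illustration; only the attn_bias construction, the boolean/additive mask branch, and the float32 softmax mirror the changed code.

import math
import torch
import torch.nn.functional as F

def masked_attention(query, key, value, attention_mask=None):
    # query/key/value: (batch, heads, seq_len, head_dim); attention_mask uses
    # the ChatGLM2 convention where True marks positions to be masked out.
    head_dim = query.size(-1)
    attn = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
                                device=query.device)
        # Invert so that True now marks positions that may attend, then turn
        # the mask into an additive bias: -inf at masked positions when the
        # mask is boolean, otherwise add the mask values directly.
        attention_mask = ~attention_mask
        if attention_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attention_mask
        attn += attn_bias
    # Softmax in float32 for numerical stability, then cast back.
    attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)
    return torch.matmul(attn, value)

# Toy usage: 1 batch, 2 heads, 4 tokens, 8-dim heads with a causal boolean mask.
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
causal_mask = ~torch.ones(1, 1, 4, 4, dtype=torch.bool).tril()  # True = masked
print(masked_attention(q, k, v, causal_mask).shape)  # torch.Size([1, 2, 4, 8])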