diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index bbb48eed..e6a2ab49 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -167,10 +167,17 @@ def baichuan_attention_forward_7b(
             query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
         )
     else:
-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
-                                            enable_mem_efficient=True):
-            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
-                                                         attn_mask=attention_mask)
+        if attention_mask is not None:
+            if attention_mask.dtype == torch.bool:
+                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+
+        scaling_factor = 1 / math.sqrt(query_states.size(-1))
+        attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
+        if attention_mask is not None:
+            attn_output += attention_mask
+        attn_output = torch.softmax(attn_output, -1)
+        attn_output = torch.matmul(attn_output, value_states)
+
         attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 8276b967..898a673a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -17,6 +17,7 @@
 # https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
 #
 
+import math
 import torch
 from typing import Optional, Tuple, List
 import torch.nn.functional as F
@@ -370,9 +371,13 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                                                                              value_layer,
                                                                              attention_mask,
                                                                              is_causal=True)
+        elif attention_mask is None:
+            scaling_factor = 1 / math.sqrt(query_layer.size(-1))
+            attn = torch.matmul(query_layer * scaling_factor, key_layer.transpose(-2, -1))
+            attn = torch.softmax(attn, -1)
+            context_layer = torch.matmul(attn, value_layer)
         else:
-            if attention_mask is not None:
-                attention_mask = ~attention_mask
+            attention_mask = ~attention_mask
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
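
Note (not part of the patch): below is a minimal standalone sketch of the manual scaled-dot-product path that both hunks fall back to, assuming [batch, num_heads, seq_len, head_dim] inputs and PyTorch 2.x. The function name manual_sdp_attention, the toy shapes, and the closing check are illustrative only; unlike the baichuan2 hunk, which fills the boolean mask in place, the sketch builds a separate float additive mask before adding it to the scores.

import math

import torch
import torch.nn.functional as F


def manual_sdp_attention(query, key, value, attention_mask=None):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    # A boolean mask (True = attend) becomes an additive mask of 0 / -inf.
    if attention_mask is not None and attention_mask.dtype == torch.bool:
        additive_mask = torch.zeros(attention_mask.shape, dtype=query.dtype)
        additive_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
        attention_mask = additive_mask

    # softmax(Q @ K^T / sqrt(head_dim) + mask) @ V, as in the patched branches
    scaling_factor = 1 / math.sqrt(query.size(-1))
    scores = torch.matmul(query * scaling_factor, key.transpose(-2, -1))
    if attention_mask is not None:
        scores = scores + attention_mask
    weights = torch.softmax(scores, -1)
    return torch.matmul(weights, value)


# Illustrative check against PyTorch's fused kernel on CPU
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
causal_mask = torch.ones(4, 4, dtype=torch.bool).tril()
reference = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
output = manual_sdp_attention(q, k, v, causal_mask)
assert torch.allclose(output, reference, atol=1e-5)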