diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 4118f6bd..f410b8fc 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -367,17 +367,17 @@ def chatglm2_attention_forward_8eb45c(
 
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
-    if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
+    if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
         if attention_mask is None and L == S:
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             is_causal=True)
+            context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
+                                                           key_layer,
+                                                           value_layer,
+                                                           is_causal=True)
         else:
             head_dim = query_layer.size(-1)
-            attn = torch.matmul(query_layer,
+            attn = torch.matmul(query_layer.to(key_layer.dtype),
                                 key_layer.transpose(2, 3)) / math.sqrt(head_dim)
             if attention_mask is not None:
                 attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
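
For context, below is a minimal, self-contained sketch (not part of the patch) of the attention path this hunk switches to: on PyTorch 2.x it casts the query to the key's dtype and uses F.scaled_dot_product_attention with is_causal=True when no mask is given and the query/key lengths match, otherwise it falls back to an explicit matmul-plus-softmax with an additive mask. The function name core_attn_sketch, the [batch, heads, seq_len, head_dim] layout, the boolean mask convention (True marks positions to suppress), and the toy dimensions in the usage comment are assumptions for illustration, not taken from the patch.

    # Standalone sketch of the patched attention path (illustrative only).
    import math
    import torch
    import torch.nn.functional as F

    def core_attn_sketch(query_layer, key_layer, value_layer, attention_mask=None):
        # query/key/value assumed shaped [batch, heads, seq_len, head_dim];
        # query may be in a different dtype than key/value (e.g. after quantized matmuls).
        L, S = query_layer.shape[2], key_layer.shape[2]
        if attention_mask is None and L == S:
            # Causal SDPA path; cast the query so all inputs share the key's dtype.
            return F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                  key_layer,
                                                  value_layer,
                                                  is_causal=True)
        # Fallback: explicit scaled dot-product attention with an optional additive mask.
        head_dim = query_layer.size(-1)
        attn = torch.matmul(query_layer.to(key_layer.dtype),
                            key_layer.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            # Assumes a boolean mask broadcastable to [batch, heads, L, S],
            # with True at positions that should be masked out.
            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                    device=query_layer.device)
            attn_bias.masked_fill_(attention_mask, float("-inf"))
            attn = attn + attn_bias
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value_layer.dtype)
        return torch.matmul(attn, value_layer)

    # Example usage with toy dimensions (assumed):
    # q = k = v = torch.randn(1, 2, 8, 16)
    # out = core_attn_sketch(q, k, v)  # -> [1, 2, 8, 16]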