diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm.py b/python/llm/src/ipex_llm/transformers/models/chatglm.py
index b17ff131..9b9888f4 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm.py
@@ -104,7 +104,7 @@ def attention_fn(
         present = None
 
     pytorch_major_version = int(torch.__version__.split('.')[0])
-    if query_layer.size(0) > 1 and pytorch_major_version >= 2:
+    if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
@@ -129,18 +129,35 @@ def attention_fn(
                                                                              attention_mask,
                                                                              is_causal=True)
         else:
+            # attention_mask is not None only when past_key_value is not None and q_len > 1
             if attention_mask is not None:
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                        device=query_layer.device)
                 attention_mask = ~attention_mask
-                attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
+            else:
+                attn_bias = None
             if torch.is_autocast_cpu_enabled():
                 query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
                 key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
                 value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
                 attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             attention_mask)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                                  key_layer,
+                                                                                  value_layer,
+                                                                                  attention_mask)
+            else:
+                head_dim = query_layer.size(-1)
+                attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                    key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+                if attn_bias is not None:
+                    attn += attn_bias
+                attn = F.softmax(attn, dim=-1,
+                                 dtype=torch.float32).to(value_layer.dtype)
+                context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
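
For reference, below is a minimal standalone sketch of the eager attention fallback that the new `else:` branch implements: build an additive bias from a boolean mask, then matmul / scale / bias / softmax / matmul. All tensor names, shapes, and the comparison against torch's fused kernel are illustrative assumptions, not code taken from chatglm.py.

# Standalone sketch (not part of the patch); shapes are hypothetical.
import math
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Hypothetical boolean mask where True marks positions to keep.
keep = torch.ones(batch, 1, q_len, kv_len, dtype=torch.bool)

# Same recipe as the patch: start from a zero bias, write -inf where the mask drops a position.
attn_bias = torch.zeros(keep.shape, dtype=query.dtype)
attn_bias.masked_fill_(keep.logical_not(), float("-inf"))

# Manual scaled-dot-product attention, mirroring the new fallback branch.
attn = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
attn = attn + attn_bias
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)
context = torch.matmul(attn, value)

# Should agree with PyTorch's fused kernel on the same inputs.
reference = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
assert torch.allclose(context, reference, atol=1e-5)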