fix chatglm run error (#11045)

* fix chatglm

* update

* fix style
Xin Qiu 2024-05-16 15:39:18 +08:00 committed by GitHub
parent 8cae897643
commit 6be70283b7


@@ -104,7 +104,7 @@ def attention_fn(
     present = None
     pytorch_major_version = int(torch.__version__.split('.')[0])
-    if query_layer.size(0) > 1 and pytorch_major_version >= 2:
+    if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
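The guard change in the hunk above means the PyTorch 2.x scaled_dot_product_attention path is now taken on every step, including single-token decode, rather than only when the query length is greater than one. A minimal, self-contained sketch of that dispatch (toy shapes and tensor names are illustrative, not the model's actual ones):

```python
import torch

# Toy shapes [seq, batch, heads, head_dim], matching ChatGLM's layout before the permute.
query_layer = torch.randn(1, 1, 2, 8)   # single-token decode step: size(0) == 1
key_layer = torch.randn(5, 1, 2, 8)
value_layer = torch.randn(5, 1, 2, 8)

pytorch_major_version = int(torch.__version__.split('.')[0])
if pytorch_major_version >= 2:
    # Permute to [batch, heads, seq, head_dim] and let the fused kernel do the work.
    q, k, v = (t.permute(1, 2, 0, 3) for t in (query_layer, key_layer, value_layer))
    context_layer = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(context_layer.shape)  # torch.Size([1, 2, 1, 8])
```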
@@ -129,18 +129,35 @@ def attention_fn(
                                                                               attention_mask,
                                                                               is_causal=True)
         else:
-            attention_mask = ~attention_mask
-            attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
+            # attention_mask is not None only when past_key_value is not None and q_len > 1
+            if attention_mask is not None:
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                        device=query_layer.device)
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
+            else:
+                attn_bias = None
             if torch.is_autocast_cpu_enabled():
                 query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
                 key_layer = key_layer.to(torch.get_autocast_cpu_dtype())
                 value_layer = value_layer.to(torch.get_autocast_cpu_dtype())
                 attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype())
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                              key_layer,
-                                                                              value_layer,
-                                                                              attention_mask)
+                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+                                                                                  key_layer,
+                                                                                  value_layer,
+                                                                                  attention_mask)
+            else:
+                head_dim = query_layer.size(-1)
+                attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                    key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+                if attn_bias is not None:
+                    attn += attn_bias
+                attn = F.softmax(attn, dim=-1,
+                                 dtype=torch.float32).to(value_layer.dtype)
+                context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
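The substance of the fix is in this second hunk: rather than flipping and refilling the boolean `attention_mask` in place (the two removed lines), the patched code builds a separate additive `attn_bias` and keeps a plain matmul-plus-softmax path when the fused kernel is not used. Below is a minimal, self-contained sketch of that mask-to-bias technique, assuming the scaled_dot_product_attention convention that True marks positions that may be attended; `q`, `k`, `v`, and `mask` are illustrative names, not the model's tensors:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)   # [batch, heads, q_len, head_dim]
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool).tril()  # True = may attend (causal)

# Boolean mask -> additive bias, as in the patched branch:
# masked-out positions become -inf, kept positions stay 0.
attn_bias = torch.zeros(mask.shape, dtype=q.dtype, device=q.device)
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))

# Manual fallback: scaled matmul, add the bias, softmax in float32, weight the values.
head_dim = q.size(-1)
attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn = attn + attn_bias
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(v.dtype)
manual_ctx = torch.matmul(attn, v)

# The fused kernel accepts the same additive bias and agrees with the manual path.
sdpa_ctx = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
print(torch.allclose(manual_ctx, sdpa_ctx, atol=1e-5))  # True
```

Note that the removed lines tried to write -inf into a boolean tensor, which a bool dtype cannot represent; building a separate float `attn_bias` avoids that, which is consistent with the run error this commit addresses.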