LLM: fix accuracy issue of chatglm3 (#9830)
* add attn mask for first token
* fix
* fix
* change attn calculation
* fix
* fix
* fix style
* fix style
parent 3147ebe63d
commit 5df31db773
1 changed file with 15 additions and 3 deletions
@@ -366,14 +366,26 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
-        if attention_mask is None and use_flash_attention(query_layer):
+        L, S = query_layer.shape[2], key_layer.shape[2]
+        if attention_mask is None and (use_flash_attention(query_layer) or
+                                       L == S and query_layer.device.type == "cpu"):
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
                                                                              is_causal=True)
         elif attention_mask is None:
-            scaling_factor = 1 / math.sqrt(query_layer.size(-1))
-            attn = torch.matmul(query_layer * scaling_factor, key_layer.transpose(-2, -1))
+            head_dim = query_layer.size(-1)
+            attn = torch.matmul(query_layer,
+                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+            if L == S:
+                # first token, need attention mask
+                attn_bias = torch.zeros(L, S, dtype=query_layer.dtype,
+                                        device=query_layer.device)
+                temp_mask = torch.ones(L, S, dtype=torch.bool,
+                                       device=query_layer.device).tril(diagonal=0)
+                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+                attn_bias.to(query_layer.dtype)
+                attn += attn_bias
             attn = torch.softmax(attn, -1)
             context_layer = torch.matmul(attn, value_layer)
         else:
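The patch only applies the causal bias when L == S, i.e. the first (prefill) token where the query attends over the whole prompt; during decoding with a KV cache the single query token may attend to all cached keys, so no mask is needed. Below is a minimal standalone sketch (not part of the patch) of the masking technique the new manual path uses: build a causal bias with tril plus masked_fill_, add it to the scaled scores, and check the result against torch.nn.functional.scaled_dot_product_attention with is_causal=True. The tensor shapes and random inputs are illustrative assumptions.

    import math
    import torch

    torch.manual_seed(0)
    # Assumed layout: (batch, heads, seq_len, head_dim)
    B, H, L, D = 1, 2, 8, 16
    query_layer = torch.randn(B, H, L, D)
    key_layer = torch.randn(B, H, L, D)
    value_layer = torch.randn(B, H, L, D)
    S = key_layer.shape[2]

    # Manual path: scale scores by sqrt(head_dim), add a causal bias, softmax, then weight the values.
    head_dim = query_layer.size(-1)
    attn = torch.matmul(query_layer, key_layer.transpose(2, 3)) / math.sqrt(head_dim)
    attn_bias = torch.zeros(L, S, dtype=query_layer.dtype)
    temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    attn = torch.softmax(attn + attn_bias, -1)
    manual = torch.matmul(attn, value_layer)

    # Reference path: fused SDPA with an implicit causal mask.
    ref = torch.nn.functional.scaled_dot_product_attention(
        query_layer, key_layer, value_layer, is_causal=True)

    print(torch.allclose(manual, ref, atol=1e-5))  # expected: True

Without the bias, every query position in the prefill would attend to future tokens, which is the accuracy issue this commit fixes.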