LLM: fix chatglm3 sdp to support speculative decoding (#9900)
* fix chatglm3
* fix
* update
* meet code review
* fix
parent 9f34da7cdb
commit 8d7326ae03

1 changed file with 6 additions and 8 deletions
@@ -373,11 +373,15 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                                                                              key_layer,
                                                                              value_layer,
                                                                              is_causal=True)
-        elif attention_mask is None:
+        else:
             head_dim = query_layer.size(-1)
             attn = torch.matmul(query_layer,
                                 key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-            if L == S:
+            if attention_mask is not None:
+                attention_mask = ~attention_mask
+                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                attn += attention_mask
+            elif L == S:
                 # first token, need attention mask
                 attn_bias = torch.zeros(L, S, dtype=query_layer.dtype,
                                         device=query_layer.device)
@@ -388,12 +392,6 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                 attn += attn_bias
             attn = torch.softmax(attn, -1)
             context_layer = torch.matmul(attn, value_layer)
-        else:
-            attention_mask = ~attention_mask
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             attention_mask)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
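For reference, below is a minimal, self-contained sketch of the attention path the new `else:` branch routes through: plain matmul-softmax attention with an additive bias instead of `scaled_dot_product_attention(..., is_causal=True)`. It assumes a boolean `attention_mask` whose `True` entries mark positions that must not be attended (the pre-flip convention implied by `~attention_mask` above). The function name `masked_core_attn`, the toy shapes, and the causal-bias construction in the `elif L == S:` branch (whose body falls between the two hunks and is not fully shown in this diff) are illustrative assumptions, not the repository's code.

import math
import torch

def masked_core_attn(query_layer, key_layer, value_layer, attention_mask=None):
    # Assumed shapes: query_layer [batch, heads, L, head_dim],
    # key_layer / value_layer [batch, heads, S, head_dim].
    # attention_mask: optional bool tensor broadcastable to [batch, 1, L, S],
    # True = position must NOT be attended.
    L, S = query_layer.size(2), key_layer.size(2)
    head_dim = query_layer.size(-1)
    attn = torch.matmul(query_layer, key_layer.transpose(2, 3)) / math.sqrt(head_dim)

    if attention_mask is not None:
        # Speculative decoding verifies several draft tokens at once, so L > 1 while
        # the cached sequence length S > L; neither is_causal=True nor an L == S
        # causal bias covers that case, hence the explicit additive mask.
        attn_bias = torch.zeros_like(attn)
        attn_bias.masked_fill_(attention_mask, float("-inf"))
        attn = attn + attn_bias
    elif L == S:
        # First token (prefill) without an explicit mask: apply a plain causal bias.
        causal = torch.ones(L, S, dtype=torch.bool, device=query_layer.device).tril()
        attn_bias = torch.zeros(L, S, dtype=query_layer.dtype, device=query_layer.device)
        attn_bias.masked_fill_(causal.logical_not(), float("-inf"))
        attn = attn + attn_bias

    attn = torch.softmax(attn, -1)
    return torch.matmul(attn, value_layer)

Usage sketch for the speculative case, verifying 4 draft tokens against a 16-token KV cache (L=4, S=16), where each draft token may attend to the cached prefix and to earlier drafts only:

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)
past = k.size(2) - q.size(2)
mask = torch.ones(4, 16, dtype=torch.bool).triu(diagonal=past + 1)  # True = blocked
out = masked_core_attn(q, k, v, mask)  # -> shape [1, 2, 4, 8]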