support for Baichuan/Baichuan2 13B Chat running speculative decoding (#9921)
* support for Baichuan/Baichuan2 13B Chat running speculative decoding * fix style
This commit is contained in:
parent
97f0cd8975
commit
fb91c97fe8
1 changed files with 10 additions and 1 deletions
|
|
@ -284,7 +284,16 @@ def baichuan_attention_forward_13b(
|
|||
attention_mask = attention_mask[:, :, -1:, :]
|
||||
else:
|
||||
attention_mask = attention_mask[:, -1:, :]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
if attention_mask.shape[-2] == attn_weights.shape[-2]:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
else:
|
||||
# support for Baichuan/Baichuan2 13B Chat running speculative decoding
|
||||
# split attention mask on dim -2
|
||||
split_sizes = [attention_mask.shape[-2] - attn_weights.shape[-2],
|
||||
attn_weights.shape[-2]]
|
||||
# the last chunk of the split is the new attention mask
|
||||
attention_mask = attention_mask.split(split_sizes, dim=-2)[-1]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue