support for Baichuan/Baichuan2 13B Chat running speculative decoding (#9921)

* support for Baichuan/Baichuan2 13B Chat running speculative decoding

* fix style
This commit is contained in:
Heyang Sun 2024-01-22 09:11:44 +08:00 committed by GitHub
parent 97f0cd8975
commit fb91c97fe8

View file

@@ -284,7 +284,16 @@ def baichuan_attention_forward_13b(
attention_mask = attention_mask[:, :, -1:, :]
else:
attention_mask = attention_mask[:, -1:, :]
attn_weights = attn_weights + attention_mask
if attention_mask.shape[-2] == attn_weights.shape[-2]:
attn_weights = attn_weights + attention_mask
else:
# support for Baichuan/Baichuan2 13B Chat running speculative decoding
# split attention mask on dim -2
split_sizes = [attention_mask.shape[-2] - attn_weights.shape[-2],
attn_weights.shape[-2]]
# the last chunk of the split is the new attention mask
attention_mask = attention_mask.split(split_sizes, dim=-2)[-1]
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)