diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 8cd48ec4..21bf40db 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -284,7 +284,16 @@ def baichuan_attention_forward_13b(
                 attention_mask = attention_mask[:, :, -1:, :]
             else:
                 attention_mask = attention_mask[:, -1:, :]
-        attn_weights = attn_weights + attention_mask
+        if attention_mask.shape[-2] == attn_weights.shape[-2]:
+            attn_weights = attn_weights + attention_mask
+        else:
+            # support for Baichuan/Baichuan2 13B Chat running speculative decoding
+            # split attention mask on dim -2
+            split_sizes = [attention_mask.shape[-2] - attn_weights.shape[-2],
+                           attn_weights.shape[-2]]
+            # the last chunk of the split mask is the new attention mask
+            attention_mask = attention_mask.split(split_sizes, dim=-2)[-1]
+            attn_weights = attn_weights + attention_mask
         attn_weights = torch.max(
             attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
         )
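
As a standalone sketch of the mask-splitting logic introduced above (not part of the PR; the tensor shapes are hypothetical), the idea is that during speculative decoding `attn_weights` may cover only the `q_len` draft tokens being verified, while the cached causal mask covers every position seen so far, so the mask is split on dim -2 and only the last `q_len` rows are kept:

```python
import torch

# Hypothetical shapes: the causal mask spans all 16 positions seen so far,
# while attn_weights spans only the 4 draft tokens being verified.
attention_mask = torch.zeros(1, 1, 16, 16)  # [bsz, heads, kv_len, kv_len]
attn_weights = torch.randn(1, 1, 4, 16)     # [bsz, heads, q_len, kv_len]

if attention_mask.shape[-2] != attn_weights.shape[-2]:
    # Split the mask on dim -2 so the second chunk has exactly q_len rows,
    # then keep that last chunk as the mask for the current queries.
    split_sizes = [attention_mask.shape[-2] - attn_weights.shape[-2],
                   attn_weights.shape[-2]]
    attention_mask = attention_mask.split(split_sizes, dim=-2)[-1]

attn_weights = attn_weights + attention_mask
print(attn_weights.shape)  # torch.Size([1, 1, 4, 16])
```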