diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index 74414ed8..e496e68b 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -259,15 +259,15 @@ def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
 
 
 def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
-    # May use fp16 for alibi mask to further reduce memory
-    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
+    slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
     position_point = torch.arange(max_pos) - max_pos + 1
     position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
     diag = torch.diag(position_point[0])
     position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
     alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
     alibi = alibi.view(n_head, 1, max_pos)
-    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
+    alibi_mask = torch.triu(
+        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
     alibi_mask = alibi_mask.unsqueeze(0) + alibi
     if tensor.device.type == "xpu":
         alibi_mask = alibi_mask.to(tensor.device)
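
The net effect of this hunk is that the ALiBi slopes and the upper-triangular -inf mask are built directly in the dtype of the input tensor (e.g. fp16 on XPU) rather than defaulting to fp32 with a commented-out .half() path. Below is a minimal, standalone sketch of that behaviour; the names gen_alibi_mask_sketch, _get_interleave_sketch, and _fill_with_neg_inf_sketch are hypothetical stand-ins for the repo's private helpers (the slope helper assumes n_head is a power of two), not the library's actual API.

import math
import torch


def _get_interleave_sketch(n_head):
    # Standard ALiBi slopes, simplified for a power-of-two head count (assumption):
    # slope_i = 2 ** (-8 * (i + 1) / n_head)
    start = 2 ** (-(2 ** -(math.log2(n_head) - 3)))
    return [start * (start ** i) for i in range(n_head)]


def _fill_with_neg_inf_sketch(t):
    # Fill a tensor with -inf while preserving its original dtype.
    return t.float().fill_(float("-inf")).type_as(t)


def gen_alibi_mask_sketch(tensor, n_head, max_pos):
    # Slopes and causal mask follow tensor.dtype, mirroring the .to(tensor.dtype)
    # calls introduced by the patch.
    slopes = torch.tensor(_get_interleave_sketch(n_head)).to(tensor.dtype)
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(
        _fill_with_neg_inf_sketch(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
    # Broadcasting [1, max_pos, max_pos] + [n_head, 1, max_pos] -> [n_head, max_pos, max_pos]
    return alibi_mask.unsqueeze(0) + alibi


if __name__ == "__main__":
    x = torch.randn(1, 8, 16, dtype=torch.float16)
    mask = gen_alibi_mask_sketch(x, n_head=8, max_pos=16)
    print(mask.dtype, mask.shape)  # torch.float16 torch.Size([8, 16, 16])

With an fp16 input the whole mask stays in fp16, which halves the memory of the [n_head, max_pos, max_pos] buffer compared with the previous fp32 default; passing an fp32 tensor reproduces the old precision.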