LLM: unify baichuan2-13b alibi mask dtype with model dtype. (#11107)

* LLM: unify alibi mask dtype.

* fix comments.
commit 011b9faa5c (parent 0a06a6e1d4)
Author: Cengguang Zhang
Date: 2024-05-24 10:27:53 +08:00
Committed by: GitHub

@@ -259,15 +259,15 @@ def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
 def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
-    # May use fp16 for alibi mask to further reduce memory
-    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
+    slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
     position_point = torch.arange(max_pos) - max_pos + 1
     position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
     diag = torch.diag(position_point[0])
     position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
     alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
     alibi = alibi.view(n_head, 1, max_pos)
-    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
+    alibi_mask = torch.triu(
+        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
     alibi_mask = alibi_mask.unsqueeze(0) + alibi
     if tensor.device.type == "xpu":
         alibi_mask = alibi_mask.to(tensor.device)
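
For reference, the effect of the change can be checked with a minimal, self-contained sketch. The `_get_interleave` and `_fill_with_neg_inf` helpers below are simplified stand-ins (the power-of-two ALiBi slope formula and a plain -inf fill), not necessarily the repo's exact implementations, and the trailing `return alibi_mask` is assumed from the part of the function this hunk truncates:

    import math
    import torch

    def _get_interleave(n_head):
        # Simplified ALiBi slope helper; assumes n_head is a power of two.
        start = 2 ** (-(2 ** -(math.log2(n_head) - 3)))
        return [start * start ** i for i in range(n_head)]

    def _fill_with_neg_inf(t):
        # Fill a tensor with -inf, preserving its original dtype.
        return t.float().fill_(float("-inf")).type_as(t)

    def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
        # Patched version: slopes and the causal -inf mask both follow
        # the reference tensor's dtype instead of defaulting to fp32.
        slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
        position_point = torch.arange(max_pos) - max_pos + 1
        position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
        diag = torch.diag(position_point[0])
        position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
        alibi = alibi.view(n_head, 1, max_pos)
        alibi_mask = torch.triu(
            _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
        alibi_mask = alibi_mask.unsqueeze(0) + alibi
        if tensor.device.type == "xpu":
            alibi_mask = alibi_mask.to(tensor.device)
        return alibi_mask  # assumed from the untruncated function

    hidden = torch.zeros(1, 4, 8, dtype=torch.bfloat16)  # stand-in model tensor
    mask = baichuan_13b_gen_alibi_mask(hidden, n_head=8, max_pos=16)
    assert mask.dtype == hidden.dtype  # mask now matches the model dtype

Before this patch, `slopes` and the `torch.triu(...)` mask were created in default fp32 (the commented-out `.half()` calls hint at an earlier fp16 variant), so `alibi_mask` came out fp32 even for an fp16/bf16 model; deriving both from `tensor.dtype` removes that mismatch and the extra memory it costs.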