LLM: unify baichuan2-13b alibi mask dtype with model dtype. (#11107)
* LLM: unify alibi mask dtype.
* fix comments.
parent 0a06a6e1d4
commit 011b9faa5c
1 changed file with 3 additions and 3 deletions
@@ -259,15 +259,15 @@ def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
 
 def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
     # May use fp16 for alibi mask to further reduce memory
-    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
+    slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
     position_point = torch.arange(max_pos) - max_pos + 1
     position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
     diag = torch.diag(position_point[0])
     position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
     alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
     alibi = alibi.view(n_head, 1, max_pos)
-    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
+    alibi_mask = torch.triu(
+        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
     alibi_mask = alibi_mask.unsqueeze(0) + alibi
     if tensor.device.type == "xpu":
         alibi_mask = alibi_mask.to(tensor.device)
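For context, below is a minimal standalone sketch of the patched function. The _get_interleave and _fill_with_neg_inf stand-ins and the trailing return are assumptions (the real helpers live elsewhere in the file, and the hunk cuts off before the function returns); the sketch only illustrates that the generated mask now follows the dtype of the input tensor instead of staying in fp32.

import math
import torch

# Stand-ins for the repo's helpers so the sketch runs on its own; these are
# assumptions modeled on the usual ALiBi utilities, not the file's actual code.
def _get_interleave(n):
    # ALiBi slope schedule; this simplified version assumes n is a power of two.
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start ** (i + 1) for i in range(n)]

def _fill_with_neg_inf(t):
    return t.float().fill_(float("-inf")).type_as(t)

def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
    # Patched behavior: slopes and the causal -inf mask follow tensor.dtype.
    slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
    position_point = torch.arange(max_pos) - max_pos + 1
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
    diag = torch.diag(position_point[0])
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(
        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    if tensor.device.type == "xpu":
        alibi_mask = alibi_mask.to(tensor.device)
    return alibi_mask  # assumed: the hunk ends before the function's return

# A bf16 placeholder standing in for a model-dtype tensor:
hidden = torch.empty(1, 8, dtype=torch.bfloat16)
mask = baichuan_13b_gen_alibi_mask(hidden, n_head=8, max_pos=16)
assert mask.dtype == torch.bfloat16   # previously fp32 regardless of model dtype
assert mask.shape == (8, 16, 16)

Before this change the function built the mask in fp32 (both .half() casts were commented out), so running the 13B model in bf16 or fp16 added an fp32 bias to the attention scores; casting with .to(tensor.dtype) keeps the slopes, the causal mask, and the scores in one dtype.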