LLM: unify baichuan2-13b alibi mask dtype with model dtype. (#11107)
* LLM: unify alibi mask dtype.
* Fix comments.
This commit is contained in:
		
							parent
							
								
									0a06a6e1d4
								
							
						
					
					
						commit
						011b9faa5c
					
				
					 1 changed file with 3 additions and 3 deletions
				
			
		| 
						 | 
					@ -259,15 +259,15 @@ def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
 | 
					def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
 | 
				
			||||||
    # May use fp16 for alibi mask to further reduce memory
 | 
					    slopes = torch.Tensor(_get_interleave(n_head)).to(tensor.dtype)
 | 
				
			||||||
    slopes = torch.Tensor(_get_interleave(n_head))  # .half()
 | 
					 | 
				
			||||||
    position_point = torch.arange(max_pos) - max_pos + 1
 | 
					    position_point = torch.arange(max_pos) - max_pos + 1
 | 
				
			||||||
    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
 | 
					    position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
 | 
				
			||||||
    diag = torch.diag(position_point[0])
 | 
					    diag = torch.diag(position_point[0])
 | 
				
			||||||
    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
 | 
					    position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
 | 
				
			||||||
    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
 | 
					    alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
 | 
				
			||||||
    alibi = alibi.view(n_head, 1, max_pos)
 | 
					    alibi = alibi.view(n_head, 1, max_pos)
 | 
				
			||||||
    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)  # .half()
 | 
					    alibi_mask = torch.triu(
 | 
				
			||||||
 | 
					        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1).to(tensor.dtype)
 | 
				
			||||||
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
 | 
					    alibi_mask = alibi_mask.unsqueeze(0) + alibi
 | 
				
			||||||
    if tensor.device.type == "xpu":
 | 
					    if tensor.device.type == "xpu":
 | 
				
			||||||
        alibi_mask = alibi_mask.to(tensor.device)
 | 
					        alibi_mask = alibi_mask.to(tensor.device)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue