fix baichuan2 13b fp16 (#11071)

Author: Yishuo Wang
Date:   2024-05-20 11:21:20 +08:00 (committed by GitHub)
parent 7170dd9192
commit 4e97047d70

@@ -213,6 +213,7 @@ def baichuan_attention_forward_13b(
     attn_weights = attn_weights / math.sqrt(self.head_dim)
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
+    attn_weights = attn_weights.to(query_states.dtype)
     attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
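
The one-line fix guards against an implicit dtype promotion: when the model runs in fp16 but attention_mask is fp32, PyTorch promotes attn_weights to fp32 during the addition, so the softmax output and the input to the downstream kernels (e.g. the quantized-KV path behind linear_q4_0) would no longer match query_states.dtype. The sketch below demonstrates the promotion and the cast; the tensor shapes and the fp32 mask are illustrative assumptions, not values taken from the repository.

import torch

# Illustrative shapes only: (batch, heads, q_len, kv_len)
attn_weights = torch.randn(1, 2, 4, 4, dtype=torch.float16)
attention_mask = torch.zeros(1, 2, 4, 4, dtype=torch.float32)

# PyTorch type promotion: float16 + float32 -> float32,
# so the addition silently upcasts the attention scores.
promoted = attn_weights + attention_mask
assert promoted.dtype == torch.float32

# The fix in this commit: cast back to the query dtype before softmax,
# so the probabilities stay fp16 for the downstream attention kernels.
restored = promoted.to(attn_weights.dtype)
probs = torch.nn.functional.softmax(restored, dim=-1)
assert probs.dtype == torch.float16

When attention_mask is None (or already fp16), the cast is a no-op, so adding it unconditionally before the softmax is safe in all paths.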