fix baichuan2 13b 2k input (#10267)
This commit is contained in:
parent
7244fd1ba5
commit
cccb02dad1
1 changed files with 34 additions and 57 deletions
|
|
@ -242,36 +242,12 @@ def baichuan_attention_forward_13b_quantized(
|
|||
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
|
||||
if past_key_value is None:
|
||||
# should use origin attn here
|
||||
attn_weights = torch.matmul(query_states,
|
||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
if q_len == 1: # inference with cache
|
||||
if len(attention_mask.size()) == 4:
|
||||
attention_mask = attention_mask[:, :, -1:, :]
|
||||
else:
|
||||
attention_mask = attention_mask[:, -1:, :]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if use_cache:
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
k_cache, v_cache = init_fp8_kv_cache(
|
||||
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
||||
device=device
|
||||
)
|
||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||
key_states, value_states)
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
else:
|
||||
k_cache, v_cache = past_key_value
|
||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||
|
|
@ -300,6 +276,7 @@ def baichuan_attention_forward_13b_quantized(
|
|||
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = attn_weights.to(hidden_states.dtype)
|
||||
|
||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
|
@ -488,7 +465,7 @@ def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
|
|||
return alibi_mask
|
||||
|
||||
|
||||
MASK_BLOCK_SIZE = 64
|
||||
MASK_BLOCK_SIZE = 512
|
||||
|
||||
|
||||
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
|
||||
|
|
|
|||
Loading…
Reference in a new issue