fix baichuan2 13b 2k input (#10267)
This commit is contained in:
parent
7244fd1ba5
commit
cccb02dad1
1 changed files with 34 additions and 57 deletions
|
|
@ -242,71 +242,48 @@ def baichuan_attention_forward_13b_quantized(
|
||||||
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
|
|
||||||
if past_key_value is None:
|
if past_key_value is None:
|
||||||
# should use origin attn here
|
kv_seq_len = key_states.shape[-2]
|
||||||
attn_weights = torch.matmul(query_states,
|
k_cache, v_cache = init_fp8_kv_cache(
|
||||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
||||||
|
device=device
|
||||||
if attention_mask is not None:
|
)
|
||||||
if q_len == 1: # inference with cache
|
|
||||||
if len(attention_mask.size()) == 4:
|
|
||||||
attention_mask = attention_mask[:, :, -1:, :]
|
|
||||||
else:
|
|
||||||
attention_mask = attention_mask[:, -1:, :]
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = torch.max(
|
|
||||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
k_cache, v_cache = init_fp8_kv_cache(
|
|
||||||
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
|
||||||
device=device
|
|
||||||
)
|
|
||||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
|
||||||
key_states, value_states)
|
|
||||||
past_key_value = (key_states, value_states)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
k_cache, v_cache = past_key_value
|
k_cache, v_cache = past_key_value
|
||||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||||
key_states, value_states)
|
key_states, value_states)
|
||||||
past_key_value = (key_states, value_states)
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||||
else:
|
else:
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
||||||
|
|
||||||
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if q_len == 1: # inference with cache
|
if q_len == 1: # inference with cache
|
||||||
if len(attention_mask.size()) == 4:
|
if len(attention_mask.size()) == 4:
|
||||||
attention_mask = attention_mask[:, :, -1:, :]
|
attention_mask = attention_mask[:, :, -1:, :]
|
||||||
else:
|
else:
|
||||||
attention_mask = attention_mask[:, -1:, :]
|
attention_mask = attention_mask[:, -1:, :]
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
attn_weights = torch.max(attn_weights,
|
attn_weights = torch.max(attn_weights,
|
||||||
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
attn_weights = attn_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
else:
|
else:
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||||
value_states.transpose(-1, -2))
|
value_states.transpose(-1, -2))
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
attn_output = attn_output.transpose(1, 2)
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
@ -488,7 +465,7 @@ def baichuan_13b_gen_alibi_mask(tensor, n_head, max_pos):
|
||||||
return alibi_mask
|
return alibi_mask
|
||||||
|
|
||||||
|
|
||||||
MASK_BLOCK_SIZE = 64
|
MASK_BLOCK_SIZE = 512
|
||||||
|
|
||||||
|
|
||||||
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
|
def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue