optimize qwen rope (#9737)
parent 4c032a433e
commit 13ea6330bd
1 changed file with 17 additions and 24 deletions
@@ -86,13 +86,23 @@ def qwen_attention_forward(
     if rotary_pos_emb_list is not None:
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
-            rotary_pos_emb = rotary_pos_emb_list[0]
-            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
+            if query.device.type == 'xpu':
+                cos, sin = rotary_pos_emb_list[0]
+                cos = cos[:, -cur_len:, :, :]
+                sin = sin[:, -cur_len:, :, :]
+                rot_dim = cos.shape[-1]
+                query_cur = query[..., :rot_dim]
+                key_cur = key[..., :rot_dim]
+                torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
+                torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
+            else:
+                rotary_pos_emb = rotary_pos_emb_list[0]
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query = apply_rotary_pos_emb(query, q_pos_emb)
+                key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
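The new XPU branch swaps the generic apply_rotary_pos_emb path for IPEX's fused in-place kernel, applied only to the first rot_dim channels of query and key. For orientation, below is a minimal plain-PyTorch sketch of the half-rotation RoPE that the fallback path computes; the [batch, seq, heads, head_dim] layout, the rotate_half helper, and the cos/sin shapes are illustrative assumptions rather than code from this file:

import torch

def rotate_half(x):
    # Rotate the rotary channels by halves: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(t, cos, sin):
    # t: [batch, seq, heads, head_dim]; cos/sin already sliced to the current
    # length and broadcastable as [1, seq, 1, rot_dim].
    rot_dim = cos.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = t_rot * cos + rotate_half(t_rot) * sin
    return torch.cat((t_rot, t_pass), dim=-1)

The fused torch.ops.torch_ipex.apply_rotary_embedding call writes the rotated values back into query_cur / key_cur in place, so the XPU path skips the splits, multiplies, and concatenations above, which is presumably where the speedup of this commit comes from.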
@@ -195,23 +205,6 @@ def qwen_attention_forward(
                           None,
                           Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))

-    # Remove for efficiency issue on Arc, maybe add later.
-    # if not self.use_cache_quantization and SUPPORT_TORCH2:
-    #     if attention_mask is not None:
-    #         attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
-    #         if causal_mask is not None:
-    #             attention_mask = attention_mask.masked_fill(~causal_mask,
-    #                                                         torch.finfo(query.dtype).min)
-    #     else:
-    #         attention_mask = causal_mask
-    #     attn_output = F.scaled_dot_product_attention(
-    #         query, key, value, attn_mask=attention_mask
-    #     ).transpose(1, 2)
-    #     attn_weight = None
-    # else:
-    #     attn_output, attn_weight = self._attn(
-    #         query, key, value, causal_mask, attention_mask, head_mask
-    #     )
     attn_output, attn_weight = self._attn(
         query, key, value, causal_mask, attention_mask, head_mask
     )
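This hunk only deletes the commented-out scaled-dot-product-attention path (left disabled because of the efficiency issue on Arc noted in the removed comment) and keeps the eager self._attn call. For context, a small self-contained sketch of the SDPA call shape those comments describe; the tensor sizes and the way the causal mask is folded into an additive mask here are illustrative assumptions:

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 32, 128, 128
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Fold the causal mask into an additive float mask, mirroring the
# masked_fill(~causal_mask, torch.finfo(dtype).min) in the removed comments.
causal_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
attn_mask = torch.zeros(seq, seq).masked_fill(~causal_mask,
                                              torch.finfo(query.dtype).min)

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
attn_output = attn_output.transpose(1, 2)  # back to [batch, seq, heads, head_dim]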