optimize qwen rope (#9737)

Yishuo Wang 2023-12-20 17:34:34 +08:00 committed by GitHub
parent 4c032a433e
commit 13ea6330bd


```diff
@@ -86,13 +86,23 @@ def qwen_attention_forward(
     if rotary_pos_emb_list is not None:
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
-            rotary_pos_emb = rotary_pos_emb_list[0]
-            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
+            if query.device.type == 'xpu':
+                cos, sin = rotary_pos_emb_list[0]
+                cos = cos[:, -cur_len:, :, :]
+                sin = sin[:, -cur_len:, :, :]
+                rot_dim = cos.shape[-1]
+                query_cur = query[..., :rot_dim]
+                key_cur = key[..., :rot_dim]
+                torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
+                torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
+            else:
+                rotary_pos_emb = rotary_pos_emb_list[0]
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query = apply_rotary_pos_emb(query, q_pos_emb)
+                key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
```
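The new XPU branch replaces the Python-level `apply_rotary_pos_emb` calls with IPEX's fused rotary kernel. Note the in-place trick: `query_cur` and `key_cur` are views of the first `rot_dim` channels, and passing the same tensor as both input and output lets the kernel rotate `query`/`key` without allocating intermediates. Below is a minimal sketch of the rotate-half math the reference path computes (and the fused call collapses into one kernel launch), assuming Qwen's standard non-interleaved layout; `apply_rotary_reference` and the toy shapes are illustrative, not from the source:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotate-half: split the rotary channels into two halves
    # and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_reference(t: torch.Tensor, cos: torch.Tensor,
                           sin: torch.Tensor) -> torch.Tensor:
    # Only the first rot_dim channels are rotated; the tail passes through,
    # mirroring the query[..., :rot_dim] / key[..., :rot_dim] split above.
    rot_dim = cos.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = t_rot * cos + rotate_half(t_rot) * sin
    return torch.cat((t_rot, t_pass), dim=-1)

# Toy shapes (batch, seq, heads, head_dim); cos/sin broadcast over heads.
t = torch.randn(1, 4, 2, 8)
cos, sin = torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8)
out = apply_rotary_reference(t, cos, sin)
```

The reference path builds this result out of several elementwise ops and concats per tensor; fusing them into a single in-place kernel call is presumably where the XPU speedup comes from.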
```diff
@@ -195,23 +205,6 @@ def qwen_attention_forward(
                 None,
                 Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
-        # Remove for efficiency issue on Arc, maybe add later.
-        # if not self.use_cache_quantization and SUPPORT_TORCH2:
-        #     if attention_mask is not None:
-        #         attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
-        #         if causal_mask is not None:
-        #             attention_mask = attention_mask.masked_fill(~causal_mask,
-        #                                                         torch.finfo(query.dtype).min)
-        #     else:
-        #         attention_mask = causal_mask
-        #     attn_output = F.scaled_dot_product_attention(
-        #         query, key, value, attn_mask=attention_mask
-        #     ).transpose(1, 2)
-        #     attn_weight = None
-        # else:
-        #     attn_output, attn_weight = self._attn(
-        #         query, key, value, causal_mask, attention_mask, head_mask
-        #     )
         attn_output, attn_weight = self._attn(
             query, key, value, causal_mask, attention_mask, head_mask
         )
```
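The second hunk deletes the commented-out `F.scaled_dot_product_attention` alternative to the eager `self._attn` path, per the note about an efficiency issue on Arc. For reference, a toy sketch of the SDPA pattern that comment block described, with made-up shapes and values (none of the numbers below come from the source):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only.
bsz, n_head, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(bsz, n_head, q_len, head_dim)
key = torch.randn(bsz, n_head, kv_len, head_dim)
value = torch.randn(bsz, n_head, kv_len, head_dim)

# Fold a lower-triangular causal mask into an additive float mask, as the
# deleted branch did via masked_fill(~causal_mask, torch.finfo(dtype).min).
causal_mask = torch.ones(q_len, kv_len, dtype=torch.bool).tril()
attention_mask = torch.zeros(q_len, kv_len).masked_fill(
    ~causal_mask, torch.finfo(query.dtype).min)

# One fused call replaces the explicit matmul/softmax/matmul in _attn.
attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask)
```

For now the commit keeps only the eager `self._attn` call; the deleted comment suggests the SDPA path may return once the Arc issue is resolved.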