diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 5e4f5dcf..a5aada9e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -86,13 +86,23 @@ def qwen_attention_forward(
     if rotary_pos_emb_list is not None:
         cur_len = query.shape[1]
         if len(rotary_pos_emb_list) == 1:
-            rotary_pos_emb = rotary_pos_emb_list[0]
-            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-            rotary_pos_emb = (rotary_pos_emb,) * 2
-            q_pos_emb, k_pos_emb = rotary_pos_emb
-            # Slice the pos emb for current inference
-            query = apply_rotary_pos_emb(query, q_pos_emb)
-            key = apply_rotary_pos_emb(key, k_pos_emb)
+            if query.device.type == 'xpu':
+                cos, sin = rotary_pos_emb_list[0]
+                cos = cos[:, -cur_len:, :, :]
+                sin = sin[:, -cur_len:, :, :]
+                rot_dim = cos.shape[-1]
+                query_cur = query[..., :rot_dim]
+                key_cur = key[..., :rot_dim]
+                torch.ops.torch_ipex.apply_rotary_embedding(query_cur, sin, cos, query_cur)
+                torch.ops.torch_ipex.apply_rotary_embedding(key_cur, sin, cos, key_cur)
+            else:
+                rotary_pos_emb = rotary_pos_emb_list[0]
+                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+                rotary_pos_emb = (rotary_pos_emb,) * 2
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                # Slice the pos emb for current inference
+                query = apply_rotary_pos_emb(query, q_pos_emb)
+                key = apply_rotary_pos_emb(key, k_pos_emb)
         else:
             query_list = []
             key_list = []
@@ -195,23 +205,6 @@ def qwen_attention_forward(
                 None,
                 Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
 
-        # Removed for an efficiency issue on Arc; may be added back later.
-        # if not self.use_cache_quantization and SUPPORT_TORCH2:
-        #     if attention_mask is not None:
-        #         attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
-        #         if causal_mask is not None:
-        #             attention_mask = attention_mask.masked_fill(~causal_mask,
-        #                                                         torch.finfo(query.dtype).min)
-        #     else:
-        #         attention_mask = causal_mask
-        #     attn_output = F.scaled_dot_product_attention(
-        #         query, key, value, attn_mask=attention_mask
-        #     ).transpose(1, 2)
-        #     attn_weight = None
-        # else:
-        #     attn_output, attn_weight = self._attn(
-        #         query, key, value, causal_mask, attention_mask, head_mask
-        #     )
         attn_output, attn_weight = self._attn(
             query, key, value, causal_mask, attention_mask, head_mask
         )
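
Note (not part of the patch): the first hunk dispatches the rotary position embedding to the fused torch_ipex kernel when the tensors live on an XPU device, and keeps the original apply_rotary_pos_emb path otherwise. The standalone sketch below illustrates that pattern under stated assumptions: apply_partial_rotary and _rotate_half are hypothetical helpers introduced here, query/key are assumed to be [batch, seq, heads, head_dim] with cos/sin broadcastable as [1, seq, 1, rot_dim], and the CPU fallback is a simplified rotate-half formulation, not the exact Qwen apply_rotary_pos_emb, so it may not match the fused op bit-for-bit.

    # Illustrative sketch only: partial rotary embedding with an XPU fast path.
    import torch

    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_partial_rotary(query: torch.Tensor,
                             key: torch.Tensor,
                             cos: torch.Tensor,
                             sin: torch.Tensor):
        # Only the first rot_dim channels of each head are rotated; the rest
        # pass through untouched, mirroring the rot_dim slicing in the patch.
        rot_dim = cos.shape[-1]
        # Basic slicing returns views, so in-place updates land in query/key.
        q_rot = query[..., :rot_dim]
        k_rot = key[..., :rot_dim]

        if query.device.type == 'xpu':
            # Fused kernel from Intel Extension for PyTorch, called exactly as
            # in the patch; the last argument is the output buffer, so passing
            # the same view updates query/key in place.
            torch.ops.torch_ipex.apply_rotary_embedding(q_rot, sin, cos, q_rot)
            torch.ops.torch_ipex.apply_rotary_embedding(k_rot, sin, cos, k_rot)
        else:
            # Plain PyTorch fallback (rotate-half convention, simplified).
            query[..., :rot_dim] = q_rot * cos + _rotate_half(q_rot) * sin
            key[..., :rot_dim] = k_rot * cos + _rotate_half(k_rot) * sin
        return query, key

One design point worth noting: because query[..., :rot_dim] is a view into the original tensor, the fused op can write its result straight back into query and key without allocating or re-concatenating the unrotated tail channels, which is what makes the XPU branch cheaper than the slice-and-rebuild path it replaces.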