From a6674f5bced7c69d036f830a2cebbb19401cbb82 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Wed, 5 Jun 2024 15:56:10 +0800
Subject: [PATCH] Fix `should_use_fuse_rope` error of Qwen1.5-MoE-A2.7B-Chat
 (#11216)

---
 python/llm/src/ipex_llm/transformers/models/cohere.py    | 4 ++--
 python/llm/src/ipex_llm/transformers/models/qwen2_moe.py | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/cohere.py b/python/llm/src/ipex_llm/transformers/models/cohere.py
index 9ee4f142..5e3437e3 100644
--- a/python/llm/src/ipex_llm/transformers/models/cohere.py
+++ b/python/llm/src/ipex_llm/transformers/models/cohere.py
@@ -54,7 +54,7 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
-from ipex_llm.transformers.models.qwen2 import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.utils.common import invalidInputError
 try:
@@ -313,7 +313,7 @@ def cohere_attention_forward_origin(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
index 9f14ca08..7d9d2b36 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -45,7 +45,7 @@ import torch.utils.checkpoint
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
 from ipex_llm.transformers.models.llama import repeat_kv
-from ipex_llm.transformers.models.qwen2 import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
@@ -333,7 +333,7 @@ def qwen2moe_attention_forward_quantized(
             "Passing `padding_mask` is deprecated and will be removed in v4.37."
             "Please make sure use `attention_mask` instead.`"
         )
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
 
     bsz, q_len, _ = hidden_states.size()
     query_states = self.q_proj(hidden_states)
@@ -435,7 +435,7 @@ def qwen2moe_attention_forward_origin(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
 
     if "padding_mask" in kwargs:
         warnings.warn(
@@ -592,7 +592,7 @@ def qwen2moe_attention_forward_sdpa(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
 
     if "padding_mask" in kwargs:
         warnings.warn(