Fix should_use_fuse_rope error of Qwen1.5-MoE-A2.7B-Chat (#11216)

binbin Deng 2024-06-05 15:56:10 +08:00 committed by GitHub
parent 231b968aba
commit a6674f5bce
2 changed files with 6 additions and 6 deletions
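The root cause is a call-convention mismatch: the old per-model helper in ipex_llm.transformers.models.qwen2 was called as should_use_fuse_rope(self, hidden_states, position_ids), while the shared helper in ipex_llm.transformers.models.utils takes (hidden_states, position_ids, training). Both changed files switch the import to the utils module and update the call sites to match. For orientation, the shared helper amounts to a device/training guard roughly like the following (a minimal sketch, assuming the usual XPU and autograd checks, not the verbatim ipex_llm implementation):

import torch

def should_use_fuse_rope(hidden_states: torch.Tensor,
                         position_ids: torch.Tensor,
                         training: bool) -> bool:
    # Sketch (assumption): fused rotary embeddings are only worthwhile on an
    # XPU device, outside of training/autograd, and when position ids exist.
    return (hidden_states.device.type == "xpu"
            and not (training and hidden_states.requires_grad)
            and position_ids is not None)

Each fixed call site now passes the module's training flag explicitly, e.g. use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training).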

Changed file 1 of 2 (the Cohere model implementation):

@@ -54,7 +54,7 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
-from ipex_llm.transformers.models.qwen2 import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.utils.common import invalidInputError
 try:
@@ -313,7 +313,7 @@ def cohere_attention_forward_origin(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,

Changed file 2 of 2 (the Qwen2-MoE model implementation):

@@ -45,7 +45,7 @@ import torch.utils.checkpoint
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
 from ipex_llm.transformers.models.llama import repeat_kv
-from ipex_llm.transformers.models.qwen2 import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
@@ -333,7 +333,7 @@ def qwen2moe_attention_forward_quantized(
             "Passing `padding_mask` is deprecated and will be removed in v4.37."
             "Please make sure use `attention_mask` instead.`"
         )
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     bsz, q_len, _ = hidden_states.size()
     query_states = self.q_proj(hidden_states)
@@ -435,7 +435,7 @@ def qwen2moe_attention_forward_origin(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     if "padding_mask" in kwargs:
         warnings.warn(
@@ -592,7 +592,7 @@ def qwen2moe_attention_forward_sdpa(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     if "padding_mask" in kwargs:
         warnings.warn(
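A quick way to exercise the corrected call shape outside the model code, assuming an ipex_llm installation (on a machine without an XPU the helper is expected to simply return False):

import torch
from ipex_llm.transformers.models.utils import should_use_fuse_rope  # import as in the diff above

hidden_states = torch.randn(1, 8, 64)        # dummy (batch, seq_len, hidden_size) activations
position_ids = torch.arange(8).unsqueeze(0)  # (batch, seq_len)

# Same positional order as the fixed call sites: (hidden_states, position_ids, training).
print(should_use_fuse_rope(hidden_states, position_ids, False))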