Fix should_use_fuse_rope error of Qwen1.5-MoE-A2.7B-Chat (#11216)
parent 231b968aba
commit a6674f5bce

2 changed files with 6 additions and 6 deletions
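The fix switches both files from the qwen2-level `should_use_fuse_rope` to the shared helper in `ipex_llm.transformers.models.utils`, and updates every call site to its `(hidden_states, position_ids, training)` argument order. As a rough sketch of what that shared helper is assumed to check (an illustration, not the actual ipex_llm source), it likely gates fused RoPE on device and training state:

import torch

# Hypothetical sketch of the utils-level helper this commit switches to;
# the real ipex_llm.transformers.models.utils.should_use_fuse_rope may differ.
def should_use_fuse_rope(hidden_states: torch.Tensor,
                         position_ids: torch.Tensor,
                         training: bool) -> bool:
    # Assumption: fused rotary embeddings only pay off on XPU, outside of
    # training/autograd, and when explicit position ids are provided.
    return (hidden_states.device.type == "xpu"
            and not (training and hidden_states.requires_grad)
            and position_ids is not None)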
@@ -54,7 +54,7 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
-from ipex_llm.transformers.models.qwen2 import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.utils.common import invalidInputError
 try:
@@ -313,7 +313,7 @@ def cohere_attention_forward_origin(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -45,7 +45,7 @@ import torch.utils.checkpoint
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
 from ipex_llm.transformers.models.llama import repeat_kv
-from ipex_llm.transformers.models.qwen2 import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
@@ -333,7 +333,7 @@ def qwen2moe_attention_forward_quantized(
             "Passing `padding_mask` is deprecated and will be removed in v4.37."
             "Please make sure use `attention_mask` instead.`"
         )
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     bsz, q_len, _ = hidden_states.size()

     query_states = self.q_proj(hidden_states)
@@ -435,7 +435,7 @@ def qwen2moe_attention_forward_origin(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)

     if "padding_mask" in kwargs:
         warnings.warn(
@@ -592,7 +592,7 @@ def qwen2moe_attention_forward_sdpa(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)

     if "padding_mask" in kwargs:
         warnings.warn(
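A minimal smoke test of the corrected call, with the import path and argument order taken from the diff above; the tensor shapes are illustrative assumptions only:

import torch
from ipex_llm.transformers.models.utils import should_use_fuse_rope

# Dummy (batch, seq_len, hidden_size) activations and (batch, seq_len) position ids.
hidden_states = torch.randn(1, 32, 2048)
position_ids = torch.arange(32).unsqueeze(0)

# New signature: tensor first, position ids second, training flag last
# (previously the module itself was passed as the first argument).
print(should_use_fuse_rope(hidden_states, position_ids, False))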