Set mistral fuse rope to false except fp6 & fp16 (#11765)
* set mistral fuse rope to false except fp6 & fp16 * lint * lint --------- Co-authored-by: ATMxsp01 <shou.xu@intel.com>
This commit is contained in:
parent
8db34057b4
commit
1b05caba2b
2 changed files with 18 additions and 0 deletions
|
|
@ -73,6 +73,7 @@ MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
|
||||||
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
|
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
|
||||||
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
|
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
|
||||||
FP6 = ggml_tensor_qtype["fp6"]
|
FP6 = ggml_tensor_qtype["fp6"]
|
||||||
|
FP16 = ggml_tensor_qtype["fp16"]
|
||||||
IQ2_XXS = ggml_tensor_qtype["gguf_iq2_xxs"]
|
IQ2_XXS = ggml_tensor_qtype["gguf_iq2_xxs"]
|
||||||
IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
|
IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
|
||||||
Q2_K = ggml_tensor_qtype["q2_k"]
|
Q2_K = ggml_tensor_qtype["q2_k"]
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,8 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
Cache = Tuple[torch.Tensor]
|
Cache = Tuple[torch.Tensor]
|
||||||
|
|
||||||
|
from ipex_llm.transformers.low_bit_linear import FP6, FP16
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||||
|
|
@ -271,6 +273,9 @@ def mistral_attention_forward_quantized(
|
||||||
original_dtype = hidden_states.dtype
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
if self.q_proj.qtype not in [FP6, FP16]:
|
||||||
|
use_fuse_rope = False
|
||||||
|
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
|
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
|
|
@ -476,6 +481,9 @@ def mistral_attention_forward_original(
|
||||||
original_dtype = hidden_states.dtype
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
if self.q_proj.qtype not in [FP6, FP16]:
|
||||||
|
use_fuse_rope = False
|
||||||
|
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
|
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
use_fuse_rope,
|
use_fuse_rope,
|
||||||
|
|
@ -699,6 +707,9 @@ def mistral_attention_forward_4_36_quantized(
|
||||||
original_dtype = hidden_states.dtype
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
if self.q_proj.qtype not in [FP6, FP16]:
|
||||||
|
use_fuse_rope = False
|
||||||
|
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
seq_len=q_len)
|
seq_len=q_len)
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
|
|
@ -917,6 +928,9 @@ def mistral_attention_forward_4_36_original(
|
||||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
if self.q_proj.qtype not in [FP6, FP16]:
|
||||||
|
use_fuse_rope = False
|
||||||
|
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
|
||||||
self.layer_idx,
|
self.layer_idx,
|
||||||
q_len)
|
q_len)
|
||||||
|
|
@ -1175,6 +1189,9 @@ def mistral_attention_forward_4_39_original(
|
||||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
if self.q_proj.qtype not in [FP6, FP16]:
|
||||||
|
use_fuse_rope = False
|
||||||
|
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||||
q_len)
|
q_len)
|
||||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue