diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index c30ca4a2..e182394a 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -73,6 +73,7 @@ MOFQ4 = ggml_tensor_qtype["mixed_fp4"] MOFQ8 = ggml_tensor_qtype["mixed_fp8"] FP8E5 = ggml_tensor_qtype["fp8_e5m2"] FP6 = ggml_tensor_qtype["fp6"] +FP16 = ggml_tensor_qtype["fp16"] IQ2_XXS = ggml_tensor_qtype["gguf_iq2_xxs"] IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"] Q2_K = ggml_tensor_qtype["q2_k"] diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index 689d9108..358cc9cc 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -63,6 +63,8 @@ try: except ImportError: Cache = Tuple[torch.Tensor] +from ipex_llm.transformers.low_bit_linear import FP6, FP16 + import os KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) @@ -271,6 +273,9 @@ def mistral_attention_forward_quantized( original_dtype = hidden_states.dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + if self.q_proj.qtype not in [FP6, FP16]: + use_fuse_rope = False + enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value) decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, @@ -476,6 +481,9 @@ def mistral_attention_forward_original( original_dtype = hidden_states.dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + if self.q_proj.qtype not in [FP6, FP16]: + use_fuse_rope = False + enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value) decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, @@ -699,6 +707,9 @@ def mistral_attention_forward_4_36_quantized( original_dtype = hidden_states.dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + if self.q_proj.qtype not in [FP6, FP16]: + use_fuse_rope = False + enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len) decoding_fast_path = use_decoding_fast_path(self.q_proj, @@ -917,6 +928,9 @@ def mistral_attention_forward_4_36_original( use_compresskv = isinstance(past_key_value, DynamicCompressCache) use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + if self.q_proj.qtype not in [FP6, FP16]: + use_fuse_rope = False + enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len) @@ -1175,6 +1189,9 @@ def mistral_attention_forward_4_39_original( use_compresskv = isinstance(past_key_value, DynamicCompressCache) use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) + if self.q_proj.qtype not in [FP6, FP16]: + use_fuse_rope = False + enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len) decoding_fast_path = use_decoding_fast_path(self.q_proj,