Fix mistral forward_qkv in q4_0 (#11781)

* Fix mistral forward_qkv without self.rotary_emb.base in q4_0.
* Replace apply_rotary_pos_emb_no_cache_xpu with rotary_half_inplaced.
* Revert https://github.com/intel-analytics/ipex-llm/pull/11765
This commit is contained in:
Qiyuan Gong 2024-08-13 16:48:19 +08:00 committed by GitHub
parent 70c828b87c
commit 3998de14f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -48,8 +48,7 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, a
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
restore_fp8_kv_cache, use_quantize_kv_cache, should_use_compresskv, \
get_compresskv_attn_mask
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
apply_rotary_pos_emb_no_cache_xpu
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
is_enough_kv_cache_room_4_36
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
@ -64,7 +63,6 @@ try:
except ImportError:
Cache = Tuple[torch.Tensor]
from ipex_llm.transformers.low_bit_linear import FP6, FP16
import os
@ -274,8 +272,6 @@ def mistral_attention_forward_quantized(
original_dtype = hidden_states.dtype
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
if self.q_proj.qtype not in [FP6, FP16]:
use_fuse_rope = False
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
@ -304,7 +300,8 @@ def mistral_attention_forward_quantized(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
0,
self.head_dim)
self.head_dim,
self.rotary_emb.base)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@ -321,11 +318,9 @@ def mistral_attention_forward_quantized(
kv_seq_len += past_key_value[0].shape[-2]
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"mistral",
self.config.rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -482,8 +477,6 @@ def mistral_attention_forward_original(
original_dtype = hidden_states.dtype
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
if self.q_proj.qtype not in [FP6, FP16]:
use_fuse_rope = False
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
@ -506,7 +499,8 @@ def mistral_attention_forward_original(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
self.head_dim,
self.rotary_emb.base)
kv_seq_len += 1
else:
@ -542,11 +536,9 @@ def mistral_attention_forward_original(
kv_seq_len += past_key_value[0].shape[-2]
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"mistral",
self.config.rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -708,8 +700,6 @@ def mistral_attention_forward_4_36_quantized(
original_dtype = hidden_states.dtype
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
if self.q_proj.qtype not in [FP6, FP16]:
use_fuse_rope = False
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
seq_len=q_len)
@ -739,7 +729,8 @@ def mistral_attention_forward_4_36_quantized(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
0,
self.head_dim)
self.head_dim,
self.rotary_emb.base)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@ -765,11 +756,9 @@ def mistral_attention_forward_4_36_quantized(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"mistral",
self.config.rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -928,8 +917,6 @@ def mistral_attention_forward_4_36_original(
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
if self.q_proj.qtype not in [FP6, FP16]:
use_fuse_rope = False
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
self.layer_idx,
@ -958,7 +945,8 @@ def mistral_attention_forward_4_36_original(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
self.head_dim,
self.rotary_emb.base)
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
@ -1011,11 +999,9 @@ def mistral_attention_forward_4_36_original(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"mistral",
self.config.rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@ -1189,8 +1175,6 @@ def mistral_attention_forward_4_39_original(
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
if self.q_proj.qtype not in [FP6, FP16]:
use_fuse_rope = False
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
q_len)
@ -1218,7 +1202,8 @@ def mistral_attention_forward_4_39_original(
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
self.head_dim,
self.rotary_emb.base)
kv_seq_len += 1
# update past_key_value's seem_tokens and kv caches.
@ -1270,11 +1255,9 @@ def mistral_attention_forward_4_39_original(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"mistral",
self.config.rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,