Fix qwen's position_ids no enough (#10572)
* fix position_ids * fix position_ids
This commit is contained in:
parent
52a2135d83
commit
5963239b46
2 changed files with 7 additions and 8 deletions
|
|
@ -595,7 +595,6 @@ def _optimize_pre(model):
|
||||||
from ipex_llm.transformers.models.bert import merge_qkv
|
from ipex_llm.transformers.models.bert import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
if model.config.model_type == "qwen":
|
if model.config.model_type == "qwen":
|
||||||
position_ids = torch.arange(0, model.config.max_position_embeddings)
|
|
||||||
rope_base = model.config.rotary_emb_base
|
rope_base = model.config.rotary_emb_base
|
||||||
from accelerate.big_modeling import init_empty_weights
|
from accelerate.big_modeling import init_empty_weights
|
||||||
|
|
||||||
|
|
@ -625,7 +624,6 @@ def _optimize_pre(model):
|
||||||
module.q_proj = q_proj
|
module.q_proj = q_proj
|
||||||
module.k_proj = k_proj
|
module.k_proj = k_proj
|
||||||
module.v_proj = v_proj
|
module.v_proj = v_proj
|
||||||
module.position_ids = position_ids
|
|
||||||
module.rope_base = rope_base
|
module.rope_base = rope_base
|
||||||
del module.c_attn
|
del module.c_attn
|
||||||
model.apply(split_qkv_proj_func)
|
model.apply(split_qkv_proj_func)
|
||||||
|
|
|
||||||
|
|
@ -136,6 +136,8 @@ def qwen_attention_forward_original(
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
# for flash attention
|
# for flash attention
|
||||||
original_dtype = hidden_states.dtype
|
original_dtype = hidden_states.dtype
|
||||||
|
position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids
|
||||||
|
rotary_pos_emb_list = rotary_pos_emb_list[:-1]
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
||||||
qtype_check = decoding_fast_path_qtype_check(self.q_proj)
|
qtype_check = decoding_fast_path_qtype_check(self.q_proj)
|
||||||
|
|
@ -147,8 +149,6 @@ def qwen_attention_forward_original(
|
||||||
cache_v = cache_v.transpose(1, 2)
|
cache_v = cache_v.transpose(1, 2)
|
||||||
|
|
||||||
kv_seq_len = cache_k.shape[-2]
|
kv_seq_len = cache_k.shape[-2]
|
||||||
self.position_ids = self.position_ids.to(device)
|
|
||||||
position_ids = self.position_ids[kv_seq_len]
|
|
||||||
base = self.rope_base
|
base = self.rope_base
|
||||||
if is_enough_kv_cache_room(layer_past, kv_seq_len):
|
if is_enough_kv_cache_room(layer_past, kv_seq_len):
|
||||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||||
|
|
@ -182,7 +182,7 @@ def qwen_attention_forward_original(
|
||||||
# query = self._split_heads(query, self.num_heads, self.head_dim)
|
# query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||||
# key = self._split_heads(key, self.num_heads, self.head_dim)
|
# key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
# value = self._split_heads(value, self.num_heads, self.head_dim)
|
# value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
if rotary_pos_emb_list is not None:
|
if len(rotary_pos_emb_list) != 0:
|
||||||
cur_len = query.shape[1]
|
cur_len = query.shape[1]
|
||||||
if len(rotary_pos_emb_list) == 1:
|
if len(rotary_pos_emb_list) == 1:
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
|
|
@ -332,6 +332,8 @@ def qwen_attention_forward_quantized(
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
position_ids = rotary_pos_emb_list[-1] # the last one is posisiton_ids
|
||||||
|
rotary_pos_emb_list = rotary_pos_emb_list[:-1]
|
||||||
|
|
||||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
||||||
# qtype_check = decoding_fast_path_qtype_check(self.q_proj)
|
# qtype_check = decoding_fast_path_qtype_check(self.q_proj)
|
||||||
|
|
@ -349,7 +351,6 @@ def qwen_attention_forward_quantized(
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
position_ids = self.position_ids[self.kv_seq_len].to(device)
|
|
||||||
base = self.rope_base
|
base = self.rope_base
|
||||||
|
|
||||||
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
|
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
|
||||||
|
|
@ -599,7 +600,7 @@ def qwen_model_forward(
|
||||||
if self.use_cache_quantization:
|
if self.use_cache_quantization:
|
||||||
past_length = past_key_values[0][0][0].size(2)
|
past_length = past_key_values[0][0][0].size(2)
|
||||||
else:
|
else:
|
||||||
past_length = past_key_values[0][0].size(-2)
|
past_length = past_key_values[0][0].size(1)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(
|
position_ids = torch.arange(
|
||||||
past_length,
|
past_length,
|
||||||
|
|
@ -651,7 +652,7 @@ def qwen_model_forward(
|
||||||
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
||||||
rotary_pos_emb_list = [
|
rotary_pos_emb_list = [
|
||||||
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
|
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
|
||||||
]
|
] + [position_ids]
|
||||||
|
|
||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue