LLM: fix qwen2 (#10356)

Ruonan Wang 2024-03-11 09:29:08 +08:00 committed by GitHub
parent f4cef95690
commit be29833b2b


@@ -262,7 +262,8 @@ def qwen2_attention_forward_origin(
         import linear_q4_0
         args = [hidden_states, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight,
                 self.q_proj.bias, self.k_proj.bias, self.v_proj.bias, position_ids, cache_k,
-                cache_v, self.q_proj.weight.qtype, kv_seq_len, self.head_dim, self.rotary_emb.base]
+                cache_v, self.q_proj.weight.qtype, self.v_proj.weight.qtype, kv_seq_len,
+                self.head_dim, self.rotary_emb.base]
         query_states, key_states, value_states = linear_q4_0.forward_qkv_bias(*args)
         kv_seq_len += 1
         if self.layer_idx == 0:
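
The fix passes self.v_proj.weight.qtype alongside self.q_proj.weight.qtype, so the fused QKV kernel no longer has to assume the value projection shares the query projection's quantization type. Below is a hypothetical pure-PyTorch sketch of what the fused call conceptually computes; forward_qkv_bias_reference and the smoke-test shapes are illustrative assumptions, not the library's API, and the real native kernel, judging from its arguments (position_ids, cache_k/cache_v, self.rotary_emb.base), additionally applies rotary position embedding and updates the KV cache, which the sketch omits.

import torch
import torch.nn.functional as F

def forward_qkv_bias_reference(hidden_states, q_weight, k_weight, v_weight,
                               q_bias, k_bias, v_bias,
                               q_qtype, v_qtype):
    # Hypothetical pure-PyTorch stand-in for linear_q4_0.forward_qkv_bias
    # (a native kernel; this reference is an assumption, not its source).
    # In the patched call, q_qtype describes the quantization of the q/k
    # projection weights and v_qtype that of the v projection; with plain
    # float tensors here both are unused and kept only to mirror the new
    # positional argument order. Rotary embedding and KV-cache updates,
    # which the native kernel appears to fuse, are omitted.
    query_states = F.linear(hidden_states, q_weight, q_bias)
    key_states = F.linear(hidden_states, k_weight, k_bias)
    value_states = F.linear(hidden_states, v_weight, v_bias)
    return query_states, key_states, value_states

# Smoke test with toy shapes (batch=1, seq_len=8, hidden_size=64).
hidden = torch.randn(1, 8, 64)
weights = [torch.randn(64, 64) for _ in range(3)]
biases = [torch.randn(64) for _ in range(3)]
q, k, v = forward_qkv_bias_reference(hidden, *weights, *biases,
                                     q_qtype=None, v_qtype=None)
assert q.shape == k.shape == v.shape == (1, 8, 64)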