update code for NPU qwen2 (#12094)

* update code

* fix
This commit is contained in:
Ruonan Wang 2024-09-20 00:58:32 -07:00 committed by GitHub
parent db7500bfd4
commit 09b8c80d9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 24 deletions

View file

@@ -85,7 +85,6 @@ class LMHeadLinear(NNFactory):
Returns:
np.ndarray: result
"""
self.prefetchWeights(1, verify_size=False)
self.set_input_tensor(X, 0)
self.elapsed = backend_lib.run(self._mm)
if len(self.out) == 1:

View file

@@ -990,6 +990,7 @@ def gen_qwen2_fused_model_forward(prefill_runner, decode_runner):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if seq_length > 1:
past_key_values_length = 0
from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
@@ -1019,6 +1020,9 @@ def gen_qwen2_fused_model_forward(prefill_runner, decode_runner):
past_key_values_length,
sliding_window=self.config.sliding_window,
)
else:
attention_mask = None
position_ids = None
# embed positions
hidden_states = inputs_embeds