parent
db7500bfd4
commit
09b8c80d9d
2 changed files with 27 additions and 24 deletions
@@ -85,7 +85,6 @@ class LMHeadLinear(NNFactory):
         Returns:
             np.ndarray: result
         """
-        self.prefetchWeights(1, verify_size=False)
         self.set_input_tensor(X, 0)
         self.elapsed = backend_lib.run(self._mm)
         if len(self.out) == 1:
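The hunk above removes a weight prefetch that ran on every call of LMHeadLinear.run, leaving the hot path to only bind the input and execute the compiled graph. A minimal sketch of the pattern, assuming a toy Layer class with stand-in costs (only the idea of a per-call vs. one-time prefetch comes from the diff; nothing here is the ipex-llm implementation):

```python
import time

# Hypothetical sketch: one-time setup (weight prefetch) hoisted out of the
# hot run() path. The sleep is a stand-in cost, not a measured NPU number.
class Layer:
    def __init__(self):
        self._prefetched = False

    def prefetch(self):
        # One-time weight staging; a no-op once it has run.
        if not self._prefetched:
            time.sleep(0.001)  # stand-in for the real prefetch cost
            self._prefetched = True

    def run(self, x, prefetch_every_call=False):
        if prefetch_every_call:
            self._prefetched = False  # emulate the old per-call prefetch
        self.prefetch()
        return 2 * x  # stand-in for the matmul

layer = Layer()
for flag in (True, False):
    layer._prefetched = False
    start = time.perf_counter()
    for _ in range(100):
        layer.run(1.0, prefetch_every_call=flag)
    print(f"prefetch_every_call={flag}: {time.perf_counter() - start:.3f}s")
```

Run over 100 calls, the per-call variant pays the prefetch cost every iteration while the hoisted variant pays it once, which is the effect of the one-line deletion above.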
@@ -990,35 +990,39 @@ def gen_qwen2_fused_model_forward(prefill_runner, decode_runner):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        past_key_values_length = 0
+        if seq_length > 1:
+            past_key_values_length = 0

-        from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
+            from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache

-        if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
-            past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_seq_length()
+            if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
+                past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
+            past_key_values_length = past_key_values.get_seq_length()

-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+            if position_ids is None:
+                device = input_ids.device if input_ids is not None else inputs_embeds.device
+                position_ids = torch.arange(
+                    past_key_values_length,
+                    seq_length + past_key_values_length,
+                    dtype=torch.long,
+                    device=device,
+                )
+                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+            else:
+                position_ids = position_ids.view(-1, seq_length).long()

-        from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+            from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=self.config.sliding_window,
-        )
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
+        else:
+            attention_mask = None
+            position_ids = None

         # embed positions
         hidden_states = inputs_embeds
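The hunk above gates all mask and position-id setup on seq_length: prefill (seq_length > 1) still converts the legacy cache to DynamicFusedNormalCache and builds explicit position ids plus a 4D causal mask, while single-token decode now passes attention_mask = None and position_ids = None through to the fused decode runner, which derives both from the KV-cache length. A minimal sketch of that split in plain PyTorch, assuming the helper name and a hand-rolled mask in place of transformers' _prepare_4d_causal_attention_mask:

```python
import torch

def build_prefill_inputs(seq_length: int, past_len: int, batch_size: int = 1):
    # Sketch of the prefill/decode split: explicit inputs for prefill,
    # (None, None) for single-token decode. Helper name is illustrative.
    if seq_length <= 1:
        return None, None
    # Positions continue after the cached prefix.
    position_ids = torch.arange(
        past_len, seq_length + past_len, dtype=torch.long
    ).unsqueeze(0)
    # Causal mask over (new tokens x all tokens): query i may attend key j
    # only when j <= past_len + i, so -inf above the (past_len+1)-th diagonal.
    mask = torch.full((seq_length, seq_length + past_len), float("-inf"))
    mask = torch.triu(mask, diagonal=past_len + 1)
    return mask.expand(batch_size, 1, -1, -1), position_ids

# Prefill of 4 tokens on top of 2 cached ones, then one decode step.
attn, pos = build_prefill_inputs(4, 2)
print(pos)         # tensor([[2, 3, 4, 5]])
print(attn.shape)  # torch.Size([1, 1, 4, 6])
print(build_prefill_inputs(1, 6))  # (None, None)
```

The design choice mirrored here is that only the first (prefill) call needs an explicit causal mask covering the cached prefix; per-token decode can reconstruct the same information from the cache length alone, so building a 4D mask there is wasted work.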