[NPU] Fix minicpm on MTL (#12599)

This commit is contained in:
binbin Deng 2024-12-24 15:37:56 +08:00 committed by GitHub
parent ad2dc965c5
commit 45f8f72a28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -29,6 +29,7 @@ import math
import numpy as np import numpy as np
from typing import Optional, Any, List from typing import Optional, Any, List
import numpy.typing as npt import numpy.typing as npt
import os
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -492,6 +493,10 @@ class LLMBaseNNFactory(NNFactory):
def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids, def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
num_heads, seq_len, head_dim): num_heads, seq_len, head_dim):
if position_ids is not None: if position_ids is not None:
if os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or\
os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1":
position_ids = self.reshape(position_ids, [-1])
else:
position_ids = self.squeeze(position_ids) position_ids = self.squeeze(position_ids)
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)