[NPU] Fix minicpm on MTL (#12599)

This commit is contained in:
binbin Deng 2024-12-24 15:37:56 +08:00 committed by GitHub
parent ad2dc965c5
commit 45f8f72a28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -29,6 +29,7 @@ import math
import numpy as np
from typing import Optional, Any, List
import numpy.typing as npt
import os
logger = logging.get_logger(__name__)
@@ -492,6 +493,10 @@ class LLMBaseNNFactory(NNFactory):
def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
num_heads, seq_len, head_dim):
if position_ids is not None:
if os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or\
os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1":
position_ids = self.reshape(position_ids, [-1])
else:
position_ids = self.squeeze(position_ids)
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)