[NPU] Fix minicpm on MTL (#12599)
This commit is contained in:
parent
ad2dc965c5
commit
45f8f72a28
1 changed files with 6 additions and 1 deletions
|
|
@ -29,6 +29,7 @@ import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Optional, Any, List
|
from typing import Optional, Any, List
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
import os
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -492,6 +493,10 @@ class LLMBaseNNFactory(NNFactory):
|
||||||
def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
|
def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
|
||||||
num_heads, seq_len, head_dim):
|
num_heads, seq_len, head_dim):
|
||||||
if position_ids is not None:
|
if position_ids is not None:
|
||||||
|
if os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or\
|
||||||
|
os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1":
|
||||||
|
position_ids = self.reshape(position_ids, [-1])
|
||||||
|
else:
|
||||||
position_ids = self.squeeze(position_ids)
|
position_ids = self.squeeze(position_ids)
|
||||||
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
|
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
|
||||||
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
|
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue