[NPU] Fix minicpm on MTL (#12599)
This commit is contained in:
		
							parent
							
								
									ad2dc965c5
								
							
						
					
					
						commit
						45f8f72a28
					
				
					 1 changed files with 6 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -29,6 +29,7 @@ import math
 | 
			
		|||
import numpy as np
 | 
			
		||||
from typing import Optional, Any, List
 | 
			
		||||
import numpy.typing as npt
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -492,7 +493,11 @@ class LLMBaseNNFactory(NNFactory):
 | 
			
		|||
    def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
 | 
			
		||||
                             num_heads, seq_len, head_dim):
 | 
			
		||||
        if position_ids is not None:
 | 
			
		||||
            position_ids = self.squeeze(position_ids)
 | 
			
		||||
            if os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or\
 | 
			
		||||
               os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1":
 | 
			
		||||
                position_ids = self.reshape(position_ids, [-1])
 | 
			
		||||
            else:
 | 
			
		||||
                position_ids = self.squeeze(position_ids)
 | 
			
		||||
            cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
			
		||||
            sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
 | 
			
		||||
            cos = self.unsqueeze(cos, [1])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue