fix qwen2 cpu (#11663)
This commit is contained in:
		
							parent
							
								
									23681fbf5c
								
							
						
					
					
						commit
						6bcdc6cc8f
					
				
					 1 changed files with 1 additions and 1 deletions
				
			
		| 
						 | 
					@ -507,7 +507,7 @@ def qwen2_mlp_forward(
 | 
				
			||||||
            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
 | 
					            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
 | 
				
			||||||
            SILU, qtype
 | 
					            SILU, qtype
 | 
				
			||||||
        ))
 | 
					        ))
 | 
				
			||||||
    elif not self.training:
 | 
					    elif x.device.type == "xpu" and not self.training:
 | 
				
			||||||
        import xe_addons
 | 
					        import xe_addons
 | 
				
			||||||
        gate = self.gate_proj(x)
 | 
					        gate = self.gate_proj(x)
 | 
				
			||||||
        up = self.up_proj(x)
 | 
					        up = self.up_proj(x)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue