fix qwen2 cpu (#11663)

This commit is contained in:
Yishuo Wang 2024-07-26 13:41:51 +08:00 committed by GitHub
parent 23681fbf5c
commit 6bcdc6cc8f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -507,7 +507,7 @@ def qwen2_mlp_forward(
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
SILU, qtype SILU, qtype
)) ))
elif not self.training: elif x.device.type == "xpu" and not self.training:
import xe_addons import xe_addons
gate = self.gate_proj(x) gate = self.gate_proj(x)
up = self.up_proj(x) up = self.up_proj(x)