fix qwen2 cpu (#11663)
This commit is contained in:
parent
23681fbf5c
commit
6bcdc6cc8f
1 changed files with 1 additions and 1 deletions
|
|
@ -507,7 +507,7 @@ def qwen2_mlp_forward(
|
||||||
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
|
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
|
||||||
SILU, qtype
|
SILU, qtype
|
||||||
))
|
))
|
||||||
elif not self.training:
|
elif x.device.type == "xpu" and not self.training:
|
||||||
import xe_addons
|
import xe_addons
|
||||||
gate = self.gate_proj(x)
|
gate = self.gate_proj(x)
|
||||||
up = self.up_proj(x)
|
up = self.up_proj(x)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue