fix lnl perf (#12700)

This commit is contained in:
Yishuo Wang 2025-01-10 18:00:58 +08:00 committed by GitHub
parent 4bf93c66e8
commit db9db51e2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -47,7 +47,7 @@ import os
import torch import torch
import torch.distributed import torch.distributed
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor, device, dtype, nn from torch import Tensor, dtype, nn
from operator import mul from operator import mul
from functools import reduce from functools import reduce
from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
@@ -294,10 +294,10 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
if hard_condition: if hard_condition:
return ( return (
batch_size > 1 batch_size > 1
or (device in ["arc"] and qtype in [SYM_INT8, FP4]) or (device_name in ["arc"] and qtype in [SYM_INT8, FP4])
or (device in ["arc", "mtl"] and qtype in [FP8E4]) or (device_name in ["arc", "mtl"] and qtype in [FP8E4])
or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0) or (device_name in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
or (device in ["bmg"] and qtype in [SYM_INT4, FP8E5]) or (device_name in ["bmg"] and qtype in [SYM_INT4, FP8E5])
) )
return False return False