Optimize with new batch kernel when batch_size=1 on LNL (#12419)
* Add use batch kernel condition for LNL
* Fix for other device judgement
* Fix based on comment
parent 7e0a840f74
commit 8fdc36c140
4 changed files with 11 additions and 7 deletions
@@ -167,7 +167,7 @@ class PromptLookupCandidateGenerator():
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
 
-        if device == "mtl":
+        if device in ["mtl", "lnl"]:
             self.max_candidates = 3
             self.min_candidates = 0
         else:
@@ -418,7 +418,7 @@ def lookup_generate(self,
             accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
             self.accept_rate.append(accept_rate)
             # Update the candidate generation strategy if needed
-            if device_name != 'mtl':
+            if device_name not in ["mtl", "lnl"]:
                 candidates_generator.update_candidate_strategy(candidate_length, n_matches,
                                                                accept_rate)
 
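For illustration, a minimal runnable sketch of what the two hunks above amount to, assuming a made-up adaptive rule for the non-mtl/lnl path (the real update_candidate_strategy in the class above is more involved): on mtl and lnl the prompt-lookup candidate window is pinned by __init__ (max_candidates=3, min_candidates=0) and never adapted afterwards.

# Sketch only, not the library code: shows how the mtl/lnl branch freezes
# the candidate window while other devices keep adapting it.
def update_candidates_sketch(device_name, max_candidates, accept_rate,
                             threshold=0.6, step=1):
    # Hypothetical adaptive rule for other devices: grow the window when
    # drafts are accepted often, shrink it otherwise.
    if device_name in ["mtl", "lnl"]:
        return max_candidates  # fixed at 3 by __init__; update is skipped
    if accept_rate > threshold:
        return max_candidates + step
    return max(1, max_candidates - step)

print(update_candidates_sketch("lnl", 3, 0.9))  # -> 3 (frozen)
print(update_candidates_sketch("arc", 3, 0.9))  # -> 4 (adapted)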
@@ -389,6 +389,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
             batch_size > 1
             or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
             or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4])
+            or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
         )
     return False
 
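This hunk is the core of the commit: on LNL the batch kernel is now also taken at batch_size == 1 for SYM_INT4, but only when the input width is a multiple of 512. A hedged sketch of the predicate in isolation (the condition is copied from the hunk; the integer qtype constants and the device argument are made up for the example):

import torch

# Made-up qtype constants for illustration only; the real ones come from
# ipex-llm's quantization definitions.
SYM_INT4, SYM_INT8, FP4, FP8E4 = 0, 1, 2, 3

def use_batch_forward_sketch(x: torch.Tensor, qtype: int, device: str) -> bool:
    batch_size = x.shape[0]
    return (
        batch_size > 1
        or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
        or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4])
        # New in this commit: LNL joins at batch_size == 1 for SYM_INT4
        # when the input width is 512-aligned.
        or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
    )

print(use_batch_forward_sketch(torch.empty(1, 4096), SYM_INT4, "lnl"))  # True
print(use_batch_forward_sketch(torch.empty(1, 4095), SYM_INT4, "lnl"))  # False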
@@ -92,7 +92,7 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int
 
 
 def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
-    return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
+    return (get_xpu_device_type(x) in ["mtl", "lnl"] and kv_group <= 1) or \
         ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
             1 < x.size(0) and x.size(0) <= 8)
 
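The same check, pulled out as a self-contained sketch with the device type passed in explicitly (get_xpu_device_type needs a real XPU tensor, so it is replaced by a parameter here): mtl and lnl qualify for the quantized KV cache when kv_group is at most 1, while arc/flex instead require a batch size in (1, 8].

import torch

def kv_cache_device_check_sketch(x: torch.Tensor, device: str, kv_group: int) -> bool:
    # mtl/lnl path: gated on the KV group size.
    # arc/flex path: gated on the batch dimension instead.
    return (device in ["mtl", "lnl"] and kv_group <= 1) or \
        (device in ["arc", "flex"] and 1 < x.size(0) <= 8)

print(kv_cache_device_check_sketch(torch.empty(1, 16), "lnl", 1))  # True
print(kv_cache_device_check_sketch(torch.empty(4, 16), "arc", 2))  # True
print(kv_cache_device_check_sketch(torch.empty(1, 16), "arc", 1))  # False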
@@ -348,7 +348,7 @@ def mlp_fusion_check(x, qtype, training):
         return False
     if qtype == FP6:
         device = get_xpu_device_type(x)
-        if device == "mtl":
+        if device in ["mtl", "lnl"]:
             return False
     return True
 
@@ -395,7 +395,7 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
     return (
         not training
         and not x.requires_grad
-        and device in ["arc", "flex", "pvc", "mtl"]  # fused layer norm cannot run on UHD
+        and device in ["arc", "flex", "pvc", "mtl", "lnl"]  # fused layer norm cannot run on UHD
         and x.numel() // x.size(-1) == 1  # fused layer norm is slower in first token
     )
 
@@ -474,7 +474,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     else:
         if use_compress_kv is None:
             return (
-                get_xpu_device_type(x) == "mtl"
+                get_xpu_device_type(x) in ["mtl", "lnl"]
                 and prompt_len >= 1800
                 and prompt_len <= 4500
             )
@@ -175,6 +175,9 @@ def get_xpu_device_type(x):
     if name.startswith("Intel(R) Arc(TM) A"):
         return "arc"
     elif name.startswith("Intel(R) Arc(TM)"):
-        return "mtl"
+        if 'V' in name:
+            return "lnl"
+        else:
+            return "mtl"
     elif name.startswith("Intel(R) Data Center GPU Flex"):
         return "flex"
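A standalone sketch of the updated name matching. The sample device strings below are illustrative guesses, not verified driver output: Lunar Lake iGPUs are assumed to carry a 'V' marker in their Arc name, while Meteor Lake iGPUs report a plain "Intel(R) Arc(TM) Graphics" string.

def classify_xpu_name_sketch(name: str) -> str:
    # Mirrors the updated branch: discrete Arc A-series first, then the
    # shared "Intel(R) Arc(TM)" prefix split on the 'V' marker.
    if name.startswith("Intel(R) Arc(TM) A"):
        return "arc"
    elif name.startswith("Intel(R) Arc(TM)"):
        if 'V' in name:
            return "lnl"
        else:
            return "mtl"
    elif name.startswith("Intel(R) Data Center GPU Flex"):
        return "flex"
    return "others"

# Illustrative device strings:
print(classify_xpu_name_sketch("Intel(R) Arc(TM) A770 Graphics"))  # arc
print(classify_xpu_name_sketch("Intel(R) Arc(TM) 140V GPU"))       # lnl
print(classify_xpu_name_sketch("Intel(R) Arc(TM) Graphics"))       # mtl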