Optimize with new batch kernel when batch_size=1 on LNL (#12419)

* Add use-batch-kernel condition for LNL

* Fix device-type check for other devices

* Fix based on review comment
Author: Yuwen Hu
Date: 2024-11-21 16:21:35 +08:00 (committed by GitHub)
parent 7e0a840f74
commit 8fdc36c140
4 changed files with 11 additions and 7 deletions

@@ -167,7 +167,7 @@ class PromptLookupCandidateGenerator():
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
-        if device == "mtl":
+        if device in ["mtl", "lnl"]:
             self.max_candidates = 3
             self.min_candidates = 0
         else:
@@ -418,7 +418,7 @@ def lookup_generate(self,
             accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
             self.accept_rate.append(accept_rate)
             # Update the candidate generation strategy if needed
-            if device_name != 'mtl':
+            if device_name not in ["mtl", "lnl"]:
                 candidates_generator.update_candidate_strategy(candidate_length, n_matches,
                                                                accept_rate)
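
On MTL and LNL the speculation window is now pinned rather than adapted per step, since update_candidate_strategy is skipped for both. A minimal sketch of the resulting policy (the non-iGPU numbers are hypothetical placeholders, not the library's real defaults):

    def candidate_bounds(device: str) -> tuple:
        # MTL/LNL iGPUs keep a fixed, small speculation window (values from the diff above)
        if device in ["mtl", "lnl"]:
            return 0, 3  # (min_candidates, max_candidates)
        # Elsewhere the strategy adapts from the observed accept rate;
        # these starting values are hypothetical
        return 2, 9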

@@ -389,6 +389,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
             batch_size > 1
             or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
             or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4])
+            or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
         )
     return False
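
This is the core of the change: on LNL the batch kernel is now also used at batch_size == 1, but only for SYM_INT4 weights whose input feature dimension is 512-aligned. A self-contained sketch of just that predicate (the SYM_INT4 constant here is a placeholder, not ipex_llm's real value):

    import torch

    SYM_INT4 = 1  # placeholder qtype id, for illustration only

    def lnl_uses_batch_kernel(x: torch.Tensor, qtype: int) -> bool:
        # On LNL the batch kernel needs the last input dim to be a multiple of 512
        return qtype == SYM_INT4 and x.shape[1] % 512 == 0

    print(lnl_uses_batch_kernel(torch.empty(1, 4096), SYM_INT4))  # True: 4096 % 512 == 0
    print(lnl_uses_batch_kernel(torch.empty(1, 4000), SYM_INT4))  # False: not 512-aligned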

@@ -92,7 +92,7 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
     def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
-        return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
+        return (get_xpu_device_type(x) in ["mtl", "lnl"] and kv_group <= 1) or \
             ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
              1 < x.size(0) and x.size(0) <= 8)
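
Read as a standalone predicate, the updated check says: quantized KV cache is allowed on MTL/LNL only without KV-head grouping, and on Arc/Flex only for batch sizes in (1, 8]. A simplified restatement (not the library's exact helper):

    def quantized_kv_ok(device: str, kv_group: int, batch_size: int) -> bool:
        # MTL/LNL: no GQA-style KV grouping; Arc/Flex: small multi-batch only
        return (device in ["mtl", "lnl"] and kv_group <= 1) or \
               (device in ["arc", "flex"] and 1 < batch_size <= 8)
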
@@ -348,7 +348,7 @@ def mlp_fusion_check(x, qtype, training):
         return False
     if qtype == FP6:
         device = get_xpu_device_type(x)
-        if device == "mtl":
+        if device in ["mtl", "lnl"]:
            return False
     return True
@@ -395,7 +395,7 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
     return (
         not training
         and not x.requires_grad
-        and device in ["arc", "flex", "pvc", "mtl"]  # fused layer norm cannot run on UHD
+        and device in ["arc", "flex", "pvc", "mtl", "lnl"]  # fused layer norm cannot run on UHD
         and x.numel() // x.size(-1) == 1  # fused layer norm is slower in first token
     )
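
The token-count guard in this condition is easiest to read with a worked example: x.numel() // x.size(-1) is the number of rows being normalized, so it equals 1 exactly for a single decode token. Shapes below are illustrative:

    import torch

    def token_count(x: torch.Tensor) -> int:
        # numel divided by the hidden size = number of tokens in x
        return x.numel() // x.size(-1)

    print(token_count(torch.empty(1, 1, 4096)))    # 1   -> decode step, fused LN eligible
    print(token_count(torch.empty(1, 128, 4096)))  # 128 -> prefill, fused LN skipped
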
@@ -474,7 +474,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     else:
         if use_compress_kv is None:
             return (
-                get_xpu_device_type(x) == "mtl"
+                get_xpu_device_type(x) in ["mtl", "lnl"]
                 and prompt_len >= 1800
                 and prompt_len <= 4500
             )
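
The auto-enable window for CompressKV on MTL/LNL, restated as a standalone predicate (simplified sketch; the use_compress_kv override handled earlier in the function is omitted):

    def compresskv_auto_enabled(device: str, prompt_len: int) -> bool:
        # Only mid-length prompts are worth compressing on these iGPUs
        return device in ["mtl", "lnl"] and 1800 <= prompt_len <= 4500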

@@ -175,7 +175,10 @@ def get_xpu_device_type(x):
     if name.startswith("Intel(R) Arc(TM) A"):
         return "arc"
     elif name.startswith("Intel(R) Arc(TM)"):
-        return "mtl"
+        if 'V' in name:
+            return "lnl"
+        else:
+            return "mtl"
     elif name.startswith("Intel(R) Data Center GPU Flex"):
         return "flex"
     elif name.startswith("Intel(R) Data Center GPU Max"):
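
The 'V' check keys on Lunar Lake's Arc naming, where the iGPU model numbers carry a V suffix (for example Arc 140V), while Meteor Lake iGPUs report a plain "Intel(R) Arc(TM) Graphics" name and discrete A-series cards take their own branch. A standalone sketch of the decision order; the exact name strings are illustrative, not an exhaustive list:

    def classify_arc(name: str) -> str:
        # Same order as the diff above: discrete A-series first,
        # then the 'V' suffix separates LNL from MTL iGPUs
        if name.startswith("Intel(R) Arc(TM) A"):
            return "arc"
        if name.startswith("Intel(R) Arc(TM)"):
            return "lnl" if "V" in name else "mtl"
        return "unknown"

    print(classify_arc("Intel(R) Arc(TM) A770 Graphics"))  # arc
    print(classify_arc("Intel(R) Arc(TM) 140V GPU"))       # lnl (illustrative name)
    print(classify_arc("Intel(R) Arc(TM) Graphics"))       # mtl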