Optimize with new batch kernel when batch_size=1 on LNL (#12419)
* Add use batch kernel condition for LNL
* Fix for other device judgement
* Fix based on comment
This commit is contained in:
parent 7e0a840f74
commit 8fdc36c140
4 changed files with 11 additions and 7 deletions
@@ -167,7 +167,7 @@ class PromptLookupCandidateGenerator():
         self.num_output_tokens = num_output_tokens
         self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
-        if device == "mtl":
+        if device in ["mtl", "lnl"]:
             self.max_candidates = 3
             self.min_candidates = 0
         else:
@@ -418,7 +418,7 @@ def lookup_generate(self,
             accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
             self.accept_rate.append(accept_rate)
             # Update the candidate generation strategy if needed
-            if device_name != 'mtl':
+            if device_name not in ["mtl", "lnl"]:
                 candidates_generator.update_candidate_strategy(candidate_length, n_matches,
                                                                accept_rate)
@@ -389,6 +389,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
             batch_size > 1
             or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
             or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4])
+            or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
         )
     return False
@@ -92,7 +92,7 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int


 def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
-    return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
+    return (get_xpu_device_type(x) in ["mtl", "lnl"] and kv_group <= 1) or \
         ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
             1 < x.size(0) and x.size(0) <= 8)
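
Read as a predicate over the precomputed device type, the updated check is roughly the following (a sketch; in the real code the device string is derived from the tensor via get_xpu_device_type):

    # Sketch of kv_cache_device_check with the device type passed in directly.
    def kv_cache_device_ok(device: str, batch_size: int, kv_group: int) -> bool:
        if device in ["mtl", "lnl"]:
            # iGPUs: only ungrouped KV caches qualify.
            return kv_group <= 1
        if device in ["arc", "flex"]:
            # dGPUs: small multi-batch workloads qualify.
            return 1 < batch_size <= 8
        return False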
@@ -348,7 +348,7 @@ def mlp_fusion_check(x, qtype, training):
         return False
     if qtype == FP6:
         device = get_xpu_device_type(x)
-        if device == "mtl":
+        if device in ["mtl", "lnl"]:
            return False
     return True
@@ -395,7 +395,7 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
     return (
         not training
         and not x.requires_grad
-        and device in ["arc", "flex", "pvc", "mtl"]  # fused layer norm cannot run on UHD
+        and device in ["arc", "flex", "pvc", "mtl", "lnl"]  # fused layer norm cannot run on UHD
         and x.numel() // x.size(-1) == 1  # fused layer norm is slower in first token
     )
@@ -474,7 +474,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     else:
         if use_compress_kv is None:
             return (
-                get_xpu_device_type(x) == "mtl"
+                get_xpu_device_type(x) in ["mtl", "lnl"]
                 and prompt_len >= 1800
                 and prompt_len <= 4500
            )
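
In other words, with no user override, the compressed KV cache is now chosen on both MTL and LNL, and only for mid-length prompts; a sketch of the default rule (bounds copied from the diff above):

    # Sketch of the default compress-KV rule on MTL/LNL integrated GPUs.
    def compresskv_default(device: str, prompt_len: int) -> bool:
        return device in ["mtl", "lnl"] and 1800 <= prompt_len <= 4500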
@@ -175,7 +175,10 @@ def get_xpu_device_type(x):
     if name.startswith("Intel(R) Arc(TM) A"):
         return "arc"
     elif name.startswith("Intel(R) Arc(TM)"):
-        return "mtl"
+        if 'V' in name:
+            return "lnl"
+        else:
+            return "mtl"
     elif name.startswith("Intel(R) Data Center GPU Flex"):
         return "flex"
     elif name.startswith("Intel(R) Data Center GPU Max"):
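
The 'V' check works because Lunar Lake iGPUs are branded with a 'V' suffix (e.g. Arc 140V) while Meteor Lake iGPUs are not; a sketch with assumed device-name strings (the real strings come from the XPU runtime's device name query):

    # Sketch of the Arc-name classification; the example strings are assumptions.
    def classify_arc_name(name: str) -> str:
        if name.startswith("Intel(R) Arc(TM) A"):
            return "arc"  # discrete Arc A-series
        elif name.startswith("Intel(R) Arc(TM)"):
            return "lnl" if "V" in name else "mtl"
        return "unknown"

    assert classify_arc_name("Intel(R) Arc(TM) A770 Graphics") == "arc"
    assert classify_arc_name("Intel(R) Arc(TM) 140V GPU") == "lnl"   # assumed LNL name
    assert classify_arc_name("Intel(R) Arc(TM) Graphics") == "mtl"   # assumed MTL name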