Optimize with new batch kernel when batch_size=1 on LNL (#12419)
				
					
				
			* Add use batch kernel condition for LNL * Fix for other device judgement * Fix based on comment
This commit is contained in:
		
							parent
							
								
									7e0a840f74
								
							
						
					
					
						commit
						8fdc36c140
					
				
					 4 changed files with 11 additions and 7 deletions
				
			
		| 
						 | 
				
			
			@ -167,7 +167,7 @@ class PromptLookupCandidateGenerator():
 | 
			
		|||
        self.num_output_tokens = num_output_tokens
 | 
			
		||||
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
 | 
			
		||||
 | 
			
		||||
        if device == "mtl":
 | 
			
		||||
        if device in ["mtl", "lnl"]:
 | 
			
		||||
            self.max_candidates = 3
 | 
			
		||||
            self.min_candidates = 0
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -418,7 +418,7 @@ def lookup_generate(self,
 | 
			
		|||
            accept_rate = self.n_matched/self.n_drafted if self.n_drafted > 0 else 1
 | 
			
		||||
            self.accept_rate.append(accept_rate)
 | 
			
		||||
            # Update the candidate generation strategy if needed
 | 
			
		||||
            if device_name != 'mtl':
 | 
			
		||||
            if device_name not in ["mtl", "lnl"]:
 | 
			
		||||
                candidates_generator.update_candidate_strategy(candidate_length, n_matches,
 | 
			
		||||
                                                               accept_rate)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -389,6 +389,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):
 | 
			
		|||
            batch_size > 1
 | 
			
		||||
            or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4])
 | 
			
		||||
            or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4])
 | 
			
		||||
            or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0)
 | 
			
		||||
        )
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -92,7 +92,7 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
 | 
			
		||||
    return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
 | 
			
		||||
    return (get_xpu_device_type(x) in ["mtl", "lnl"] and kv_group <= 1) or \
 | 
			
		||||
        ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
 | 
			
		||||
            1 < x.size(0) and x.size(0) <= 8)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -348,7 +348,7 @@ def mlp_fusion_check(x, qtype, training):
 | 
			
		|||
        return False
 | 
			
		||||
    if qtype == FP6:
 | 
			
		||||
        device = get_xpu_device_type(x)
 | 
			
		||||
        if device == "mtl":
 | 
			
		||||
        if device in ["mtl", "lnl"]:
 | 
			
		||||
            return False
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -395,7 +395,7 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 | 
			
		|||
    return (
 | 
			
		||||
        not training
 | 
			
		||||
        and not x.requires_grad
 | 
			
		||||
        and device in ["arc", "flex", "pvc", "mtl"]  # fused layer norm cannot run on UHD
 | 
			
		||||
        and device in ["arc", "flex", "pvc", "mtl", "lnl"]  # fused layer norm cannot run on UHD
 | 
			
		||||
        and x.numel() // x.size(-1) == 1  # fused layer norm is slower in first token
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -474,7 +474,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
 | 
			
		|||
    else:
 | 
			
		||||
        if use_compress_kv is None:
 | 
			
		||||
            return (
 | 
			
		||||
                get_xpu_device_type(x) == "mtl"
 | 
			
		||||
                get_xpu_device_type(x) in ["mtl", "lnl"]
 | 
			
		||||
                and prompt_len >= 1800
 | 
			
		||||
                and prompt_len <= 4500
 | 
			
		||||
            )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -175,7 +175,10 @@ def get_xpu_device_type(x):
 | 
			
		|||
    if name.startswith("Intel(R) Arc(TM) A"):
 | 
			
		||||
        return "arc"
 | 
			
		||||
    elif name.startswith("Intel(R) Arc(TM)"):
 | 
			
		||||
        return "mtl"
 | 
			
		||||
        if 'V' in name:
 | 
			
		||||
            return "lnl"
 | 
			
		||||
        else:
 | 
			
		||||
            return "mtl"
 | 
			
		||||
    elif name.startswith("Intel(R) Data Center GPU Flex"):
 | 
			
		||||
        return "flex"
 | 
			
		||||
    elif name.startswith("Intel(R) Data Center GPU Max"):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue