Enable kv cache on arc batch (#10308)
This commit is contained in:
		
							parent
							
								
									5809a3f5fe
								
							
						
					
					
						commit
						df2b84f7de
					
				
					 1 changed file with 6 additions and 1 deletion
				
			
		| 
						 | 
					@ -73,11 +73,16 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
 | 
				
			||||||
    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
 | 
					    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
 | 
				
			||||||
        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
 | 
					        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return x.device.type == 'xpu' and get_xpu_device_type(x) == "mtl" \
 | 
					        return x.device.type == 'xpu' and kv_cache_device_check(x) \
 | 
				
			||||||
            and hasattr(linear, "qtype") and \
 | 
					            and hasattr(linear, "qtype") and \
 | 
				
			||||||
            linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 | 
					            linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def kv_cache_device_check(x: torch.Tensor) -> bool:
					    """Return whether the device holding ``x`` passes the kv-cache device check.

					    The check passes when ``get_xpu_device_type(x)`` reports ``"mtl"``, or
					    when it reports ``"arc"`` and the first dimension of ``x`` is strictly
					    between 1 and 8 (i.e. 2..7).

					    NOTE(review): ``x.size(0)`` is presumably the batch size -- confirm
					    against the callers.

					    :param x: tensor whose device type and leading dimension are inspected.
					    :return: ``True`` if the check passes, ``False`` otherwise.
					    """
					    # Look the device type up once instead of once per comparison.
					    device_type = get_xpu_device_type(x)
					    if device_type == "mtl":
					        return True
					    # On "arc" only leading dimensions 2..7 pass the check.
					    return device_type == "arc" and 1 < x.size(0) < 8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
 | 
					def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
 | 
				
			||||||
    max_length = current_length + FP8_KV_ALLOC_LENGTH
 | 
					    max_length = current_length + FP8_KV_ALLOC_LENGTH
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue