Enable kv cache quantization by default for flex when 1 < batch <= 8 (#10584)
* Enable kv cache quantization by default for flex when 1 < batch <= 8.
* Change the upper bound from < 8 to <= 8.
parent b44f7adbad
commit f4537798c1

1 changed file with 3 additions and 2 deletions
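For reference, a minimal sketch of how the `BIGDL_QUANTIZE_KV_CACHE` override interacts with the new default; the variable name comes from the diff below, and the surrounding BigDL call sites are assumed:

import os

# Force KV cache quantization on, overriding the device/batch heuristic.
os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"

# Any numeric value other than 1 disables it.
os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "0"

# Unset to fall back to the default check, which after this commit
# enables quantization on flex (and arc) when 1 < batch <= 8.
os.environ.pop("BIGDL_QUANTIZE_KV_CACHE", None)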
@@ -73,7 +73,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):

 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
-        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
+        return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
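This first hunk also changes how the override is parsed: a numeric comparison replaces the exact string match. A standalone illustration of the behavioral difference (not BigDL code):

def override_enabled(value: str) -> bool:
    # Mirrors the new check: int() comparison instead of string equality.
    return int(value) == 1

assert override_enabled("1")
assert override_enabled("01")     # the old `== "1"` check rejected this
assert override_enabled(" 1 ")    # int() tolerates surrounding whitespace
assert not override_enabled("0")
# Non-numeric values such as "true" now raise ValueError instead of
# silently evaluating to False.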
@@ -82,7 +82,8 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:

 def kv_cache_device_check(x: torch.Tensor) -> bool:
     return get_xpu_device_type(x) == "mtl" or \
-        (get_xpu_device_type(x) == "arc" and 1 < x.size(0) and x.size(0) < 8)
+        ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and \
+            1 < x.size(0) and x.size(0) <= 8)


 def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
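A self-contained sketch of the updated heuristic, with get_xpu_device_type stubbed out as a plain string argument since it is defined elsewhere in BigDL:

def kv_cache_device_check_sketch(device_type: str, batch_size: int) -> bool:
    # mtl always defaults to a quantized KV cache; arc and flex
    # only when the batch size is in (1, 8].
    return device_type == "mtl" or \
        (device_type in ("arc", "flex") and 1 < batch_size <= 8)

assert kv_cache_device_check_sketch("mtl", 1)
assert kv_cache_device_check_sketch("flex", 8)      # upper bound now inclusive
assert not kv_cache_device_check_sketch("flex", 1)  # batch of 1 still excluded
assert not kv_cache_device_check_sketch("arc", 9)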