Only use quantized KV cache on MTL (#9862)
This commit is contained in:
		
							parent
							
								
									146076bdb5
								
							
						
					
					
						commit
						36496d60ac
					
				
					 2 changed files with 5 additions and 2 deletions
				
			
		| 
						 | 
					@ -18,7 +18,7 @@ import os
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from bigdl.llm.utils.common import invalidInputError
 | 
					from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
					from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
				
			||||||
from bigdl.llm.transformers.utils import get_ipex_version
 | 
					from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
 | 
					def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
 | 
				
			||||||
| 
						 | 
					@ -63,7 +63,8 @@ def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
 | 
				
			||||||
    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
 | 
					    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
 | 
				
			||||||
        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
 | 
					        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
 | 
					        return x.device.type == 'xpu' and get_xpu_device_type(x) == "mtl" \
 | 
				
			||||||
 | 
					            and hasattr(linear, "qtype") and \
 | 
				
			||||||
            linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 | 
					            linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -169,6 +169,8 @@ def get_xpu_device_type(x):
 | 
				
			||||||
    name = torch.xpu.get_device_name(x.device.index)
 | 
					    name = torch.xpu.get_device_name(x.device.index)
 | 
				
			||||||
    if name.startswith("Intel(R) Arc(TM) A"):
 | 
					    if name.startswith("Intel(R) Arc(TM) A"):
 | 
				
			||||||
        return "arc"
 | 
					        return "arc"
 | 
				
			||||||
 | 
					    elif name.startswith("Intel(R) Arc(TM)"):
 | 
				
			||||||
 | 
					        return "mtl"
 | 
				
			||||||
    elif name.startswith("Intel(R) Data Center GPU Flex"):
 | 
					    elif name.startswith("Intel(R) Data Center GPU Flex"):
 | 
				
			||||||
        return "flex"
 | 
					        return "flex"
 | 
				
			||||||
    elif name.startswith("Intel(R) Data Center GPU Max"):
 | 
					    elif name.startswith("Intel(R) Data Center GPU Max"):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue