change xmx condition (#9896)
This commit is contained in:

parent 0e69bfe6b0
commit 6637860ddf

3 changed files with 18 additions and 53 deletions
bigdl/llm/transformers/low_bit_linear.py

@@ -470,7 +470,7 @@ class LowBitLinear(nn.Linear):
             try:
                 import intel_extension_for_pytorch
                 import linear_q4_0
-                from bigdl.llm.utils.xmx_checker import use_xmx
+                from bigdl.llm.transformers.models.utils import use_xmx
             except ModuleNotFoundError:
                 invalidInputError(False,
                                   "Please `pip install bigdl_core_xe` first.")
bigdl/llm/transformers/models/utils.py

@@ -21,6 +21,11 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 
 
+SYM_INT4 = ggml_tensor_qtype["sym_int4"]
+SYM_INT8 = ggml_tensor_qtype["sym_int8"]
+FP8 = ggml_tensor_qtype["fp8"]
+
+
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,

@@ -263,3 +268,15 @@ def mlp_fusion_check(x, qtype, training):
     if training or x.requires_grad:
         return False
     return True
+
+
+def use_xmx(x: torch.Tensor, qtype: int):
+    device = get_xpu_device_type(x)
+    return (
+        device in ["arc", "flex", "pvc"]
+        and qtype in [SYM_INT4, SYM_INT8, FP8]
+        and (
+            (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
+            or 1 < x.size(0) <= 8
+        )
+    )
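The new helper keys the XMX gate off the device family reported by get_xpu_device_type rather than raw device-name strings. Below is a minimal sketch of the same condition with the device string passed in explicitly, so it can run without an XPU; use_xmx_sketch and the plain-string qtype values are illustrative stand-ins, not the real API (the actual helper derives the device from the tensor and compares ggml_tensor_qtype integers):

import torch

def use_xmx_sketch(x: torch.Tensor, qtype: str, device: str) -> bool:
    # Mirrors the new use_xmx condition, with the device string injected
    # instead of derived via get_xpu_device_type.
    return (
        device in ["arc", "flex", "pvc"]
        and qtype in ["sym_int4", "sym_int8", "fp8"]
        # arc/flex additionally accept fp32 inputs with up to 64 rows...
        and ((device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
             # ...while 2..8 rows qualify on all three device families
             or 1 < x.size(0) <= 8)
    )

assert use_xmx_sketch(torch.empty(32, 4096), "sym_int4", "arc")      # fp32, 32 rows: allowed on arc
assert not use_xmx_sketch(torch.empty(32, 4096), "sym_int4", "pvc")  # but not on pvc
assert not use_xmx_sketch(torch.empty(1, 4096), "sym_int4", "arc")   # single-row inputs never use XMX

In effect, arc and flex widen the fp32 batch window to 2..64 rows, while pvc keeps the 2..8 window that applies to all dtypes.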
bigdl/llm/utils/xmx_checker.py (deleted)

@@ -1,52 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-import intel_extension_for_pytorch as ipex
-from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-
-
-SYM_INT4 = ggml_tensor_qtype["sym_int4"]
-SYM_INT8 = ggml_tensor_qtype["sym_int8"]
-NF4 = ggml_tensor_qtype["nf4"]
-NF3 = ggml_tensor_qtype["nf3"]
-FP8 = ggml_tensor_qtype["fp8"]
-FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
-MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
-
-
-class XMXChecker:
-    def __init__(self):
-        self.support_xmx = self.check_xmx()
-        self.supported_qtype = [SYM_INT4, SYM_INT8, FP8]
-
-    @staticmethod
-    def check_xmx():
-        name = torch.xpu.get_device_name(0)
-        # todo: not sure how to check xmx or how to get device name for now
-        return "Arc(TM)" in name or "GPU Max" in name or "GPU Flex" in name
-
-    def check(self, input_tensor: torch.Tensor, qtype: int):
-        return self.support_xmx and 1 < input_tensor.shape[0] <= 8 and \
-            qtype in self.supported_qtype
-
-
-xmx_checker = XMXChecker()
-
-
-def use_xmx(input_tensor: torch.Tensor, qtype: int):
-    return xmx_checker.check(input_tensor, qtype)
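For contrast, here is a sketch of where the removed checker and the new condition diverge; device detection is stubbed out so it runs anywhere (old_check and new_check are illustrative reductions of XMXChecker.check and the new use_xmx, with qtypes as plain strings):

import torch

def old_check(x: torch.Tensor, qtype: str) -> bool:
    # Reduction of the deleted XMXChecker.check, with support_xmx assumed True.
    return 1 < x.shape[0] <= 8 and qtype in ["sym_int4", "sym_int8", "fp8"]

def new_check(x: torch.Tensor, qtype: str) -> bool:
    # Reduction of the new use_xmx, with the device fixed to "arc".
    return qtype in ["sym_int4", "sym_int8", "fp8"] and (
        (x.dtype == torch.float32 and 1 < x.size(0) <= 64) or 1 < x.size(0) <= 8
    )

x = torch.empty(32, 4096)        # fp32, 32 rows
print(old_check(x, "sym_int4"))  # False: 32 rows exceeds the old 2..8 window
print(new_check(x, "sym_int4"))  # True: fp32 on arc now qualifies up to 64 rows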