LLM: better FP16 support for Intel GPUs (#9791)
* initial support
* fix
* fix style
* fix
* limit esimd usage condition
* refactor code
* fix style
* small fix
* meet code review
* small fix
parent 7d9f6c6efc
commit 99bddd3ab4

4 changed files with 153 additions and 75 deletions
@@ -200,8 +200,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             bias=has_bias,
                             mp_group=mp_group,
                         )
-                        device_type = module.qweight.data.device.type
-                        invalidInputError(device_type != "meta",
+                        device = module.qweight.data.device
+                        invalidInputError(device.type != "meta",
                                           "converting from meta device is not supported")
                         # Copy the weights
                         paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
@@ -209,11 +209,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                                  quantized=True,
                                                  _shape=(out_features, in_features),
                                                  convert_shape_only=convert_shape_only,
-                                                 qtype=qtype).to(device_type)
+                                                 qtype=qtype).to(device)
                         new_linear._parameters['weight'] = paramsLowBit
                         if has_bias:
                             new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device_type)
+                                .to(device)
                     elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
                         new_linear = LowBitLinear(
                             in_features,
@@ -223,44 +223,39 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             mp_group=mp_group,
                         )

-                        device_type = module.weight.data.device.type
+                        device = module.weight.data.device
                         # Copy the weights
                         paramsLowBit = FP4Params(data=module.weight.data,
                                                  requires_grad=False,
                                                  quantized=False,
                                                  _shape=None,
                                                  convert_shape_only=convert_shape_only,
-                                                 qtype=qtype).to(device_type)
+                                                 qtype=qtype).to(device)
                         new_linear._parameters['weight'] = paramsLowBit
                         if module.bias is not None:
                             new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device_type)
+                                .to(device)
                     elif qtype == ggml_tensor_qtype["fp16"]:
-                        #  only support two size now
-                        #  may generalize to other sizes
-                        if module.in_features in [4096, 11008]:
-                            # esimd fp16 path
-                            new_linear = FP16Linear(
-                                in_features,
-                                out_features,
-                                qtype,
-                                module.bias is not None,
-                                mp_group=mp_group,
-                            )
-                            device_type = module.weight.data.device.type
-
-                            # convert here
-                            m, n = module.weight.data.shape
-                            if module.in_features == 11008:
-                                trans_weight = module.weight.data.reshape(m//8, 8, n)
-                                trans_weight = trans_weight.transpose(1, 2).contiguous()
-                            elif module.in_features == 4096:
-                                trans_weight = module.weight.data.reshape(m//16, 16, n)
-                                trans_weight = trans_weight.transpose(1, 2).contiguous()
-                            new_linear._parameters['weight'] = nn.Parameter(trans_weight)
-                            if module.bias is not None:
-                                new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                    .to(device_type)
+                        module.to(torch.float16)
+                        new_linear = FP16Linear(
+                            in_features,
+                            out_features,
+                            module.bias is not None,
+                            mp_group=mp_group,
+                        )
+                        device = module.weight.data.device
+                        from bigdl.llm.transformers.utils import get_ipex_version
+                        if get_ipex_version() < "2.1.10+xpu":
+                            new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                        else:
+                            # only from 2.1, ipex provides matmul_bias_out
+                            # so we need to transpose weight
+                            new_weight = module.weight.transpose(0, 1).contiguous()
+                            new_linear._parameters['weight'] = nn.Parameter(new_weight)
+                            new_linear.weight_type = 2
+                        if module.bias is not None:
+                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                                .to(device)
                     elif qtype == ggml_tensor_qtype["bf16"]:
                         module.to(torch.bfloat16)
                         new_linear = BF16Linear(
@@ -269,12 +264,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             module.bias is not None,
                             mp_group=mp_group,
                         )
-                        device_type = module.weight.data.device.type
+                        device = module.weight.data.device
                         # convert here
                         new_linear._parameters['weight'] = nn.Parameter(module.weight)
                         if module.bias is not None:
                             new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device_type)
+                                .to(device)

                     if new_linear is not None:
                         if not module.training:
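How the new fp16 branch above is typically reached, as a minimal sketch: loading a model through bigdl-llm's transformers-style API with a low-bit dtype runs the conversion in _replace_with_low_bit_linear, which swaps each eligible nn.Linear for an FP16Linear. The load_in_low_bit="fp16" argument is assumed from the qtype handled above, and the model id is illustrative only, not taken from this diff.

    # Hedged sketch: assumes bigdl-llm's AutoModelForCausalLM wrapper accepts
    # load_in_low_bit="fp16"; the model id is an example, not part of this commit.
    from bigdl.llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",   # illustrative model id
        load_in_low_bit="fp16",            # routes linear layers through the fp16 branch above
    )
    model = model.to("xpu")                # FP16Linear only runs on Intel GPUs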
@@ -50,7 +50,8 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from bigdl.llm.transformers.utils import get_autocast_dtype
+from bigdl.llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
+    get_ipex_version

 T = TypeVar("T", bound="torch.nn.Module")

@@ -538,57 +539,111 @@ class LowBitLinear(nn.Linear):


 class FP16Linear(nn.Linear):
-    def __init__(self, input_features, output_features, qtype, bias=True,
-                 conver_to_half=True, mp_group=None):
+    def __init__(self, input_features, output_features, bias=True,
+                 mp_group=None, weight_type=1):
         super().__init__(input_features, output_features, bias)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
         self.weight_length = self.out_len * self.in_len
-        self.qtype = qtype
-        self.conver_to_half = conver_to_half
+        self.qtype = ggml_tensor_qtype["fp16"]
         self.mp_group = mp_group
+        # weigh_type = 1 means original weight
+        # weigh_type = 2 means weight has been transposed
+        # weigh_type = 3 means weight has been transposed by esimd method
+        self.weight_type = 1

     def forward(self, x: torch.Tensor):
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        x_shape = x.shape
-        x_2d = x.view(-1, x_shape[-1])
-
-        x0 = self.weight.data
         # only work for GPU
-        invalidInputError(x0.device.type == "xpu",
-                          "FP16 only works for GPU")
-        try:
-            import intel_extension_for_pytorch
-            import linear_fp16_esimd
-        except ModuleNotFoundError:
-            invalidInputError(False,
-                              "Please `pip install bigdl_core_xe` first.")
+        invalidInputError(x.device.type == "xpu",
+                          "FP16Linear only works for Intel GPUs")
+        x = x.to(torch.float16)
+        if self.bias is not None and self.bias.dtype != x.dtype:
+                self.bias.data = self.bias.data.to(x.dtype)
+        if self.weight is not None and self.weight.dtype != x.dtype:
+            self.weight.data = self.weight.data.to(x.dtype)

-        if x_2d.is_contiguous() is False:
-            x_2d = x_2d.contiguous()
-
-        if x_2d.shape[0] > 1:
-            # first token or batch size > 1, re-convert weight
-            original_weight = self.weight.data.transpose(1, 2)
-            original_weight = original_weight.reshape(self.out_len, self.in_len)
-            result = F.linear(x_2d, original_weight.contiguous())
-            del original_weight
+        if not self.use_esimd_kernel(x):
+            if get_ipex_version() < "2.1.10+xpu":
+                if self.weight_type == 2:
+                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight_type = 1
+                return F.linear(x, self.weight, self.bias)
+            else:
+                if self.weight_type == 1:
+                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight_type = 2
+                return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
         else:
-            # rest token, use esimd optimization
-            result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+            if self.weight_type != 3:
+                # convert weight first to use esimd fp16 kernel
+                self.convert_weight_for_esimd_kernel()
+            # esimd fp16 kernel for inference
+            x_shape = x.shape
+            x_2d = x.view(-1, x_shape[-1])
+            if x_2d.is_contiguous() is False:
+                x_2d = x_2d.contiguous()

-        new_shape = x_shape[:-1] + (self.out_len,)
-        result = result.view(new_shape)
-        if self.mp_group is not None:
-            from deepspeed import comm as dist
-            dist.inference_all_reduce(result, group=self.mp_group)
-        if self.bias is not None:
-            result += self.bias
+            try:
+                import intel_extension_for_pytorch
+                import linear_fp16_esimd
+            except ModuleNotFoundError:
+                invalidInputError(False,
+                                  "Please `pip install bigdl_core_xe_esimd` first.")

-        return result.to(x.dtype)
+            if x_2d.shape[0] > 1:
+                # first token or batch size > 1, re-convert weight
+                original_weight = self.weight.data.transpose(1, 2)
+                original_weight = original_weight.reshape(self.out_len, self.in_len)
+                result = F.linear(x_2d, original_weight.contiguous())
+                del original_weight
+            else:
+                # rest token, use esimd optimization
+                result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+
+            new_shape = x_shape[:-1] + (self.out_len,)
+            result = result.view(new_shape)
+            if self.mp_group is not None:
+                from deepspeed import comm as dist
+                dist.inference_all_reduce(result, group=self.mp_group)
+            if self.bias is not None:
+                result += self.bias
+
+            return result.to(x.dtype)
+
+    def use_esimd_kernel(self, x):
+        gpu_type = get_xpu_device_type(x)
+        # esimd kernel can only be used for Arc and Flex
+        if gpu_type not in ["arc", "flex"]:
+            return False
+        # now esimd kernel can only be used for specific cases (llama2-7b shape)
+        if self.in_len == 11008 and self.out_features == 4096:
+            return True
+        if self.in_len == 4096 and self.out_features in [4096, 11008]:
+            # seems has some issue with Mistral,
+            # need a further look to check whether can be used for other out features
+            return True
+        return False
+
+    def convert_weight_for_esimd_kernel(self):
+        m, n = self.out_len, self.in_len
+        if self.in_len == 11008:
+            if self.weight_type == 2:
+                trans_weight = self.weight.data.transpose(0, 1)
+            else:
+                trans_weight = self.weight.data
+            trans_weight = trans_weight.data.reshape(m//8, 8, n)
+            trans_weight = trans_weight.transpose(1, 2).contiguous()
+            self.weight.data = trans_weight
+        elif self.in_len == 4096:
+            if self.weight_type == 2:
+                trans_weight = self.weight.data.transpose(0, 1)
+            else:
+                trans_weight = self.weight.data
+            trans_weight = trans_weight.data.reshape(m//16, 16, n)
+            trans_weight = trans_weight.transpose(1, 2).contiguous()
+            self.weight.data = trans_weight
+        self.weight_type = 3


 class BF16Linear(nn.Linear):
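A minimal usage sketch of the reworked FP16Linear, assuming it lives in the low_bit_linear module shown in this hunk (file paths are not visible in this page), that the machine has an Intel Arc or Flex GPU with intel_extension_for_pytorch and the esimd package installed, and that the shapes match the llama2-7b cases the esimd kernel currently targets.

    # Hedged sketch; import path, device, and shapes are assumptions for illustration.
    import torch
    from bigdl.llm.transformers.low_bit_linear import FP16Linear

    layer = FP16Linear(4096, 11008, bias=False).to("xpu")   # llama2-7b-style projection
    x = torch.randn(1, 1, 4096, device="xpu")               # single decode token

    with torch.no_grad():
        # forward casts x (and the weight) to fp16, then dispatches: for this shape
        # on Arc/Flex it converts the weight once (weight_type -> 3) and calls the
        # esimd kernel; otherwise it falls back to F.linear / matmul_bias_out.
        y = layer(x)
    print(y.shape)                                           # torch.Size([1, 1, 11008])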
@@ -100,8 +100,9 @@ def llama_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
     if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
+            and qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
         if not x_2d.is_contiguous():
@@ -147,7 +148,8 @@ def llama_attention_forward_4_31(

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
-    is_q4_0 = self.q_proj.qtype == SYM_INT4
+    qtype = getattr(self.q_proj, "qtype", None)
+    is_q4_0 = qtype == SYM_INT4
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
                           enough_kv_room and bsz * q_len == 1)
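A standalone illustration of the getattr guard used in both llama hooks above, presumably needed because a projection may now be a layer without a low-bit qtype attribute; the SYM_INT4 value below is a placeholder, not the real constant.

    # Sketch only: SYM_INT4 is a stand-in value; the real constant comes from bigdl-llm.
    import torch.nn as nn

    SYM_INT4 = 2                       # placeholder
    proj = nn.Linear(4096, 4096)       # layer with no `qtype` attribute
    qtype = getattr(proj, "qtype", None)
    is_q4_0 = qtype == SYM_INT4        # False instead of AttributeError,
                                       # so the int4 fast path is simply skipped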
@@ -149,3 +149,29 @@ def get_autocast_dtype(x):
     else:
         invalidInputError(False,
                           f"Device {x.device} is not supported.")
+
+
+_ipex_version = None
+
+
+def get_ipex_version():
+
+    global _ipex_version
+    if _ipex_version is not None:
+        return _ipex_version
+
+    import intel_extension_for_pytorch as ipex
+    _ipex_version = ipex.__version__
+    return _ipex_version
+
+
+def get_xpu_device_type(x):
+    name = torch.xpu.get_device_name(x.device.index)
+    if name.startswith("Intel(R) Arc(TM) A"):
+        return "arc"
+    elif name.startswith("Intel(R) Data Center GPU Flex"):
+        return "flex"
+    elif name.startswith("Intel(R) Data Center GPU Max"):
+        return "pvc"
+    else:
+        return "others"
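Hedged usage of the two new helpers; this assumes a machine with intel_extension_for_pytorch installed and an XPU device available, otherwise the import and the torch.xpu query will fail.

    import torch
    from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type

    x = torch.empty(1, device="xpu")
    if get_xpu_device_type(x) in ["arc", "flex"]:
        print("esimd fp16 kernel is a candidate on this GPU")
    if get_ipex_version() < "2.1.10+xpu":
        print("older ipex: keep fp16 weights in their original layout")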