small fix and remove unused code about ipex (#12671)
parent c11f5f0fcd
commit a22a8c21bb

4 changed files with 4 additions and 28 deletions
```diff
@@ -847,18 +847,9 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
                         mp_group=mp_group,
                     )
                     device = module.weight.data.device
-                    from ipex_llm.transformers.utils import get_ipex_version
-                    if get_ipex_version() < "2.1.10+xpu":
-                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                    else:
-                        # only from 2.1, ipex provides matmul_bias_out
-                        # so we need to transpose weight
-                        new_weight = module.weight.transpose(0, 1).contiguous()
-                        new_linear._parameters['weight'] = nn.Parameter(new_weight)
-                        new_linear.weight_type = 2
+                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
                     if module.bias is not None:
-                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                            .to(device)
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data).to(device)
                 elif qtype == ggml_tensor_qtype["bf16"]:
                     module.to(torch.bfloat16)
                     new_linear = BF16Linear(
```
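For context on this hunk: the removed branch picked a weight layout based on the installed ipex version, because only ipex 2.1+ provides `matmul_bias_out`, which consumes a transposed weight (tagged via `weight_type = 2`). A minimal standalone sketch of the removed behavior, for reference only (`assign_fp16_weight` is a hypothetical name; `new_linear` stands in for the low-bit linear module from the diff):

```python
from torch import nn


def assign_fp16_weight(new_linear: nn.Module, module: nn.Linear,
                       ipex_version: str) -> None:
    # Pre-#12671 behavior: choose the weight layout by ipex version.
    if ipex_version < "2.1.10+xpu":
        new_linear._parameters['weight'] = nn.Parameter(module.weight)
    else:
        # ipex >= 2.1 provides matmul_bias_out, which expects the
        # transposed, contiguous layout; weight_type = 2 records that.
        new_weight = module.weight.transpose(0, 1).contiguous()
        new_linear._parameters['weight'] = nn.Parameter(new_weight)
        new_linear.weight_type = 2
```

After this change the assignment is unconditional (`nn.Parameter(module.weight)`), so the transposed path and its version probe are gone.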
```diff
@@ -51,8 +51,7 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name, \
-    get_ipex_version
+from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name
 from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm
 
 T = TypeVar("T", bound="torch.nn.Module")
```
```diff
@@ -19,7 +19,7 @@ import torch
 import warnings
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
-from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_name
+from ipex_llm.transformers.utils import get_xpu_device_name
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\
     FP6, ASYM_INT4
 
```
```diff
@@ -154,20 +154,6 @@ def get_autocast_dtype(x):
                           f"Device {x.device} is not supported.")
 
 
-_ipex_version = None
-
-
-def get_ipex_version():
-
-    global _ipex_version
-    if _ipex_version is not None:
-        return _ipex_version
-
-    import intel_extension_for_pytorch as ipex
-    _ipex_version = ipex.__version__
-    return _ipex_version
-
-
 def get_xpu_device_name(device: torch.device):
     if device.type != "xpu":
         return device.type
```
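With the cached `get_ipex_version` helper removed, any code that still needs the ipex version can read it straight off the package. A minimal sketch, not part of this PR (the helper name is hypothetical, and the plain string comparison mirrors the removed call site; a stricter check would use `packaging.version.parse`):

```python
def ipex_version_at_least(minimum: str) -> bool:
    # Import lazily, as the removed helper did, so that code paths
    # which never touch xpu do not require intel_extension_for_pytorch.
    import intel_extension_for_pytorch as ipex
    # Plain string comparison, matching the removed
    # `get_ipex_version() < "2.1.10+xpu"` call site.
    return ipex.__version__ >= minimum
```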