Support the desc_act feature in GPTQ model (#10851)

* support act_order
* update versions
* fix style
* fix bug
* clean up
parent dc27b3bc35
commit 1ce8d7bcd9
5 changed files with 56 additions and 20 deletions
@@ -13,9 +13,9 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.34.0
-BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install transformers==4.37.0
+pip install auto_gptq==0.7.1
+pip install optimum==1.14.0
 ```
 
 ### 2. Configures OneAPI environment variables
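For reference, a minimal sanity check after following the updated install instructions; this snippet is illustrative only (not part of the commit), and the distribution names are assumed to match what pip records.

```python
# Illustrative only: confirm the pinned versions from the updated instructions
# are what ended up installed in the `llm` environment.
from importlib.metadata import version

expected = {"transformers": "4.37.0", "auto-gptq": "0.7.1", "optimum": "1.14.0"}
for dist, want in expected.items():
    print(f"{dist}: installed {version(dist)}, expected {want}")
```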
@@ -18,7 +18,7 @@ import torch
 import time
 import argparse
 from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer, GPTQConfig
+from transformers import AutoTokenizer, GPTQConfig
 
 # you could tune the prompt based on your own model,
 # here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style

@@ -30,7 +30,7 @@ LLAMA2_PROMPT_FORMAT = """### HUMAN:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/Llama-2-7B-GPTQ",
+    parser.add_argument('--repo-id-or-model-path', type=str, default="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
                         help='The huggingface repo id'
                              ', or the path to the huggingface checkpoint folder')
     parser.add_argument('--prompt', type=str, default="What is AI?",

@@ -48,8 +48,9 @@ if __name__ == '__main__':
                                                  torch_dtype=torch.float,
                                                  trust_remote_code=True,).to("xpu")
 
+    print(model)
     # Load tokenizer
-    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():
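Pulling the touched example lines together, the updated flow looks roughly like the sketch below. Anything outside the hunks above (notably `load_in_4bit=True`, the prompt, and the `generate()` arguments) is an assumption based on the surrounding example, not part of this diff.

```python
# Condensed sketch of the updated example; assumptions noted in the lead-in.
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"  # new default repo id

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,   # assumed from the wider example
                                             torch_dtype=torch.float,
                                             trust_remote_code=True,).to("xpu")
print(model)  # added by this commit: inspect the converted low-bit modules

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```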
@@ -99,6 +99,11 @@ def is_lm_head(name, model_config, out_features):
         return False
 
 
+def is_gptq_linear(module):
+    return is_auto_gptq_available() and \
+        (isinstance(module, QuantLinearCuda) or isinstance(module, QuantLinearCudaOld))
+
+
 def is_linear_module(module):
 
     in_features = None

@@ -122,7 +127,7 @@ def is_linear_module(module):
             mp_group = None
         else:
             result = False
-    elif is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld):
+    elif is_gptq_linear(module):
         in_features = module.infeatures
         out_features = module.outfeatures
         mp_group = None

@@ -153,7 +158,7 @@ def is_linear_module(module):
     return result, (in_features, out_features, mp_group)
 
 
-def convert_gptq(module, awq=False, llm_awq=False):
+def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
 

@@ -164,6 +169,8 @@ def convert_gptq(module, awq=False, llm_awq=False):
         module.wf.unsqueeze(0)).to(torch.int16 if module.bits == 8 else torch.int8)
     zeros = torch.bitwise_and(zeros, (2 ** module.bits) - 1)
 
+    g_id_map = None
+
     if not awq:
         zeros = zeros + 1
     zeros = zeros.reshape(scales.shape)
@@ -183,6 +190,12 @@ def convert_gptq(module, awq=False, llm_awq=False):
         weight = torch.bitwise_and(weight, (2 ** module.bits) - 1)
         weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
 
+        if act_order:
+            invalidInputError(module.g_idx.shape[0] == weight.shape[0],
+                              "g_idx and weight shape mismatch")
+            _, g_id_map = torch.sort(module.g_idx)
+            weight = weight[g_id_map, :]
+
     # convert weight to ggml format
     weight = weight.reshape(weight.shape[0]//module.group_size, module.group_size, weight.shape[1])
     weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)

@@ -219,7 +232,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
                              weight.view(torch.uint8)], dim=-1)
     ggml_weight = ggml_weight.reshape([-1])
 
-    return ggml_weight
+    return ggml_weight, g_id_map
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
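The act_order branch above relies on `torch.sort(g_idx)` returning the permutation that groups input channels by quantization group. A toy sketch (made-up sizes, no real GPTQ module) of what that permutation looks like and how it is applied:

```python
# Toy illustration of the act_order handling above; sizes and values are made up.
import torch

# g_idx[i] is the quantization group of input channel i; with desc_act the
# groups are interleaved in the original channel order.
g_idx = torch.tensor([1, 0, 2, 0, 1, 2])

_, g_id_map = torch.sort(g_idx)   # permutation that sorts channels by group
print(g_id_map)                   # e.g. tensor([1, 3, 0, 4, 2, 5])
print(g_idx[g_id_map])            # tensor([0, 0, 1, 1, 2, 2]) -> groups are contiguous

# In convert_gptq the same permutation reorders the unpacked weight along its
# in_features dimension, so the later group_size reshape sees whole groups.
weight = torch.randn(6, 4)        # (in_features, out_features)-shaped toy weight
weight = weight[g_id_map, :]
```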
@@ -228,7 +241,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  imatrix_data=None, embedding_qtype=None,
                                  model_config=None, torch_dtype=torch.float32,
                                  enable_xetla=False,
-                                 mixed_precision=False):
+                                 mixed_precision=False,
+                                 act_order=False,
+                                 ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import LLMEmbedding, LowBitEmbedding

@@ -252,7 +267,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     optimize_lm_head = True
             with init_empty_weights():
                 new_linear = None
-                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_gptq = is_gptq_linear(module)
                 is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
                 is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
                 if is_gptq or is_awq:

@@ -264,14 +279,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         bias=has_bias,
                         mp_group=mp_group,
                         enable_xetla=enable_xetla,
-                        optimize_lm_head=optimize_lm_head
+                        optimize_lm_head=optimize_lm_head,
+                        act_order=act_order,
                     )
                     device = module.qweight.data.device
                     invalidInputError(device.type != "meta",
                                       "converting from meta device is not supported")
+                    weight, g_idx_map = convert_gptq(module,
+                                                     awq=is_awq,
+                                                     llm_awq=is_llm_awq,
+                                                     act_order=act_order)
+                    if act_order:
+                        new_linear.g_idx_map = g_idx_map
                     # Copy the weights
-                    paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
-                                                               llm_awq=is_llm_awq),
+                    paramsLowBit = FP4Params(data=weight,
                                              requires_grad=False,
                                              quantized=True,
                                              _shape=(out_features, in_features),

@@ -422,7 +443,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 model_config=model_config,
                 torch_dtype=torch_dtype,
                 enable_xetla=enable_xetla,
-                mixed_precision=mixed_precision
+                mixed_precision=mixed_precision,
+                act_order=act_order,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced

@@ -464,7 +486,7 @@ def replace_with_low_bit_linear_for_module(model, qtype, module_name=None,
             in_features, out_features, mp_group = linear_args
             with init_empty_weights():
                 new_linear = None
-                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_gptq = is_gptq_linear(module)
                 is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
                 is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
                 if is_gptq or is_awq:
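One detail worth noting: `new_linear.g_idx_map = g_idx_map` is a plain attribute assignment, but because `g_idx_map` is registered as a buffer in `LowBitLinear.__init__` (see the LowBitLinear hunks further down), PyTorch routes the assignment into the module's buffer dict. The sketch below only illustrates that PyTorch behaviour; it is not ipex-llm code.

```python
# Illustration: assigning a tensor to a registered-buffer name keeps it a buffer,
# so it still follows .to(...) and appears in state_dict().
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, in_len=4):
        super().__init__()
        self.register_buffer("g_idx_map",
                             torch.arange(in_len, dtype=torch.int64))

m = Toy()
m.g_idx_map = torch.tensor([2, 0, 3, 1])        # plain attribute assignment
print("g_idx_map" in dict(m.named_buffers()))   # True: still a registered buffer
print(m.state_dict()["g_idx_map"])              # tensor([2, 0, 3, 1])
```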
@@ -721,6 +743,10 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     if optimize_model:
         model = _optimize_pre(model)
 
+    act_order = False
+    if getattr(model, "quantization_method", None) == "gptq":
+        act_order = model.config.quantization_config.desc_act
+
     # mixed quantization needs model_config to choose custom quantization strategy
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,

@@ -731,6 +757,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         torch_dtype=torch_dtype,
         enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
+        act_order=act_order,
     )
     if not has_been_replaced:
         warnings.warn(
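The `desc_act` flag read above ultimately comes from the checkpoint's GPTQ quantization config. A quick illustration of where the attribute lives, using transformers' `GPTQConfig` (illustration only, not ipex-llm code):

```python
# Where desc_act comes from: the GPTQ quantization config attached to the model.
from transformers import GPTQConfig

quantization_config = GPTQConfig(bits=4, desc_act=True)
print(quantization_config.desc_act)  # True -> ggml_convert_low_bit enables act_order
```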
@@ -579,7 +579,7 @@ class MatMulLowBitCPU(torch.autograd.Function):
 class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None, enable_xetla=False,
-                 optimize_lm_head=False):
+                 optimize_lm_head=False, act_order=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,

@@ -603,6 +603,11 @@ class LowBitLinear(nn.Linear):
         # since performance isn't impacted.
         self.is_lm_head = self.in_len * self.out_len >= 32000 * 4096 and self.bias is None
         self.low_memory_mode = self.is_lm_head
+        self.act_order = act_order
+        if act_order:
+            self.register_buffer(
+                "g_idx_map",
+                torch.tensor([i for i in range(self.in_len)], dtype=torch.int64))
 
     def forward(self, x: torch.Tensor):
         # empty cache before and after lm_head at first token when input > 1024

@@ -640,6 +645,9 @@ class LowBitLinear(nn.Linear):
             return torch.empty(new_shape, dtype=x.dtype, device=x.device)
 
         x_2d = x.view(-1, x_shape[-1])
+
+        if self.act_order:
+            x_2d = x_2d[:, self.g_idx_map]
         # x0 for weight
         x0 = self.weight.data
 
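The forward-time `x_2d[:, self.g_idx_map]` mirrors the conversion-time `weight[g_id_map, :]`: permuting the weight's input dimension once at conversion and the activation columns at run time with the same map leaves the linear output unchanged. A toy check with plain float tensors (not the quantized kernel):

```python
# Toy check that a consistent permutation of the weight's input dimension and
# of the activation columns preserves x @ W^T; plain floats, not the low-bit path.
import torch

in_len, out_len, batch = 8, 3, 4
w = torch.randn(out_len, in_len)      # nn.Linear-style weight
x = torch.randn(batch, in_len)
g_idx_map = torch.randperm(in_len)    # stands in for argsort(g_idx)

y_ref = x @ w.t()
y_perm = x[:, g_idx_map] @ w[:, g_idx_map].t()

print(torch.allclose(y_ref, y_perm, atol=1e-6))  # True
```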
@@ -243,8 +243,6 @@ class _BaseAutoModelClass:
                 if q_config["quant_method"] == "gptq":
                     invalidInputError(q_config["bits"] == 4,
                                       "Only 4-bit gptq is supported in bigdl-llm.")
-                    invalidInputError(q_config["desc_act"] is False,
-                                      "Only desc_act=False is supported in bigdl-llm.")
                     if load_in_low_bit is not None:
                         invalidInputError(load_in_low_bit == "asym_int4",
                                           "You can only load gptq model as aysm_int4 low bit type.")

@@ -448,6 +446,8 @@ class _BaseAutoModelClass:
                 offload_dir=None
             )
         else:
+            if quant_config is not None:
+                kwargs["quantization_config"] = quant_config
             _load_pre()
             try:
                 # To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it