From 77af9bc5fac79b8295fed48f412b522379552179 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 26 Sep 2024 11:09:35 +0800
Subject: [PATCH] support passing None to low_bit in optimize_model (#12121)

---
 python/llm/src/ipex_llm/optimize.py             |  9 +--
 .../llm/src/ipex_llm/transformers/convert.py    | 66 +++++++++++--------
 2 files changed, 43 insertions(+), 32 deletions(-)

diff --git a/python/llm/src/ipex_llm/optimize.py b/python/llm/src/ipex_llm/optimize.py
index 86db591c..e010646c 100644
--- a/python/llm/src/ipex_llm/optimize.py
+++ b/python/llm/src/ipex_llm/optimize.py
@@ -202,7 +202,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     :param model: The original PyTorch model (nn.module)
     :param low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, ``'sym_int5'``,
                     ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, ``'nf4'``, ``'fp4'``,
-                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'`` or ``'bf16'``,
+                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'``, ``'bf16'`` or None,
                     ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
                     asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
                     Relevant low bit optimizations will be applied to the model.
@@ -225,10 +225,11 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     >>> # (Optional) you can also save the optimized model by calling 'save_low_bit'
     >>> model.save_low_bit(saved_dir)
     """
-    invalidInputError(low_bit in ggml_tensor_qtype,
+    invalidInputError(low_bit is None or low_bit in ggml_tensor_qtype,
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-    invalidInputError(isinstance(model, torch.nn.Module),
+    invalidInputError(isinstance(model, torch.nn.Module) or
+                      model.__class__.__name__ == "StableDiffusionPipeline",
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
     # To adapt vLLM models
@@ -249,7 +250,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
             torch_dtype = torch.float16
     else:
         torch_dtype = kwargs.get("torch_dtype", "auto")
-    qtype = ggml_tensor_qtype[low_bit]
+    qtype = ggml_tensor_qtype[low_bit] if low_bit is not None else None
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
                                  torch_dtype=torch_dtype,
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index e77bd87a..dcb86080 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1060,7 +1060,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         logger.info(f"Converting the current model to "
                     f"{list(ggml_tensor_qtype.keys())[index]} "
                     f"format......")
-    else:
+    elif qtype in gguf_mixed_qtype.values():
         index = list(gguf_mixed_qtype.values()).index(qtype)
         logger.info(f"Converting the current model to "
                     f"{list(gguf_mixed_qtype.keys())[index]} "
@@ -1089,34 +1089,35 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     enable_scale_search = use_scale_search(model_config, qtype)

     # mixed quantization needs model_config to choose custom quantization strategy
-    model, has_been_replaced = _replace_with_low_bit_linear(
-        model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
-        imatrix_data=imatrix_data,
-        embedding_qtype=embedding_qtype,
-        model_config=model_config,
-        torch_dtype=torch_dtype,
-        enable_xetla=enable_xetla,
-        mixed_precision=mixed_precision,
-        act_order=act_order,
-        enable_scale_search=enable_scale_search,
-    )
-    if not has_been_replaced:
-        warnings.warn(
-            "No linear modules were found in "
-            "your model. This can happen for some architectures such as gpt2 that uses Conv1D "
-            "instead of Linear layers. Please double check your model architecture, or submit "
-            "an issue on github if you think this is a bug."
+    if qtype is not None:
+        model, has_been_replaced = _replace_with_low_bit_linear(
+            model, qtype, modules_to_not_convert,
+            convert_shape_only, cpu_embedding,
+            imatrix_data=imatrix_data,
+            embedding_qtype=embedding_qtype,
+            model_config=model_config,
+            torch_dtype=torch_dtype,
+            enable_xetla=enable_xetla,
+            mixed_precision=mixed_precision,
+            act_order=act_order,
+            enable_scale_search=enable_scale_search,
         )
-    elif device == "cpu":
-        if not (getattr(model, "quantization_method", None) == "gptq"):
-            if torch_dtype == "auto":
-                convert_bigdl_other_module(model, torch.float32)
-            else:
-                convert_bigdl_other_module(model, torch_dtype)
-    elif device == "meta":
-        # Do nothing here for weights are empty.
-        pass
+        if not has_been_replaced:
+            warnings.warn(
+                "No linear modules were found in "
+                "your model. This can happen for some architectures such as gpt2 that uses Conv1D "
+                "instead of Linear layers. Please double check your model architecture, or submit "
+                "an issue on github if you think this is a bug."
+            )
+        elif device == "cpu":
+            if not (getattr(model, "quantization_method", None) == "gptq"):
+                if torch_dtype == "auto":
+                    convert_bigdl_other_module(model, torch.float32)
+                else:
+                    convert_bigdl_other_module(model, torch_dtype)
+        elif device == "meta":
+            # Do nothing here for weights are empty.
+            pass

     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
@@ -1221,6 +1222,15 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):


 def _optimize_post(model, lightweight_bmm=False):
+    try:
+        from diffusers import StableDiffusionPipeline
+        if isinstance(model, StableDiffusionPipeline):
+            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            model.unet.set_attn_processor(AttnProcessor2_0())
+            return model
+    except ModuleNotFoundError:
+        pass
+
     try:
         from sentence_transformers.SentenceTransformer import SentenceTransformer
         if isinstance(model, SentenceTransformer):
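
Reviewer note (not part of the patch): a minimal usage sketch of the behaviour this change enables, namely calling optimize_model with low_bit=None so the low-bit linear replacement is skipped and only the post optimizations run (for Stable Diffusion 1.5, the AttnProcessor2_0 swap added above). The top-level import path "from ipex_llm import optimize_model" and the checkpoint id "runwayml/stable-diffusion-v1-5" are assumptions for illustration, not something this patch introduces.

# Hypothetical sketch, not included in the patch.
from diffusers import StableDiffusionPipeline
from ipex_llm import optimize_model  # assumed public entry point for optimize.py's optimize_model

# Checkpoint id is illustrative only; any SD 1.5 pipeline is handled the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# New in this patch: low_bit=None passes validation, qtype stays None, so
# ggml_convert_low_bit skips _replace_with_low_bit_linear and _optimize_post
# replaces the UNet attention processor with ipex_llm's SD 1.5 AttnProcessor2_0.
pipe = optimize_model(pipe, low_bit=None)

# Existing behaviour is unchanged: an nn.Module passed with a low_bit string is
# still quantized as before, e.g. model = optimize_model(model, low_bit="sym_int4").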