support passing None to low_bit in optimize_model (#12121)

Yishuo Wang 2024-09-26 11:09:35 +08:00 committed by GitHub
parent 47e0b83cbf
commit 77af9bc5fa
2 changed files with 43 additions and 32 deletions


@@ -202,7 +202,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     :param model: The original PyTorch model (nn.module)
     :param low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, ``'sym_int5'``,
                     ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, ``'nf4'``, ``'fp4'``,
-                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'`` or ``'bf16'``,
+                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'``, ``'bf16'`` or None,
                     ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
                     asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
                     Relevant low bit optimizations will be applied to the model.
@@ -225,10 +225,11 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     >>> # (Optional) you can also save the optimized model by calling 'save_low_bit'
     >>> model.save_low_bit(saved_dir)
     """
-    invalidInputError(low_bit in ggml_tensor_qtype,
+    invalidInputError(low_bit is None or low_bit in ggml_tensor_qtype,
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-    invalidInputError(isinstance(model, torch.nn.Module),
+    invalidInputError(isinstance(model, torch.nn.Module) or
+                      model.__class__.__name__ == "StableDiffusionPipeline",
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
     # To adapt vLLM models
@@ -249,7 +250,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         torch_dtype = torch.float16
     else:
         torch_dtype = kwargs.get("torch_dtype", "auto")
-    qtype = ggml_tensor_qtype[low_bit]
+    qtype = ggml_tensor_qtype[low_bit] if low_bit is not None else None
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
                                  torch_dtype=torch_dtype,
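
With these three hunks, low_bit=None passes validation and skips weight quantization entirely while the rest of optimize_model still runs. A minimal usage sketch, assuming a transformers checkpoint (the model id is illustrative, not from this commit):

import torch
from transformers import AutoModelForCausalLM
from ipex_llm import optimize_model

# Any torch.nn.Module works here; the checkpoint name is illustrative.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             torch_dtype=torch.float16)

# Before this commit, low_bit=None failed the invalidInputError check.
# Now qtype stays None, so _replace_with_low_bit_linear is skipped and
# only the post-optimizations in _optimize_post are applied.
model = optimize_model(model, low_bit=None)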


@@ -1060,7 +1060,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         logger.info(f"Converting the current model to "
                     f"{list(ggml_tensor_qtype.keys())[index]} "
                     f"format......")
-    else:
+    elif qtype in gguf_mixed_qtype.values():
         index = list(gguf_mixed_qtype.values()).index(qtype)
         logger.info(f"Converting the current model to "
                     f"{list(gguf_mixed_qtype.keys())[index]} "
@@ -1089,6 +1089,7 @@
     enable_scale_search = use_scale_search(model_config, qtype)
     # mixed quantization needs model_config to choose custom quantization strategy
+    if qtype is not None:
         model, has_been_replaced = _replace_with_low_bit_linear(
             model, qtype, modules_to_not_convert,
             convert_shape_only, cpu_embedding,
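
Both guards exist for the same reason: with low_bit=None, qtype arrives here as None. The old bare else unconditionally ran list(gguf_mixed_qtype.values()).index(qtype), which raises ValueError for None; the new elif and the if qtype is not None: check let a None qtype fall through so no conversion is attempted. A minimal sketch of that behavior, assuming illustrative qtype integers (the real mappings live in ipex-llm's quantize module):

ggml_tensor_qtype = {"sym_int4": 2, "bf16": 22}  # illustrative values
gguf_mixed_qtype = {"gguf_q4k_s": 101}           # illustrative values

def describe_target_format(qtype):
    if qtype in ggml_tensor_qtype.values():
        index = list(ggml_tensor_qtype.values()).index(qtype)
        return list(ggml_tensor_qtype.keys())[index]
    elif qtype in gguf_mixed_qtype.values():  # old code used a bare else here
        index = list(gguf_mixed_qtype.values()).index(qtype)
        return list(gguf_mixed_qtype.keys())[index]
    return None  # qtype=None now falls through instead of raising

print(describe_target_format(2))     # sym_int4
print(describe_target_format(None))  # None; the bare else raised ValueError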
@@ -1221,6 +1222,15 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 def _optimize_post(model, lightweight_bmm=False):
+    try:
+        from diffusers import StableDiffusionPipeline
+        if isinstance(model, StableDiffusionPipeline):
+            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            model.unet.set_attn_processor(AttnProcessor2_0())
+            return model
+    except ModuleNotFoundError:
+        pass
+
     try:
         from sentence_transformers.SentenceTransformer import SentenceTransformer
         if isinstance(model, SentenceTransformer):
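
This last hunk is what low_bit=None enables in practice: _optimize_post now recognizes a diffusers StableDiffusionPipeline (which is not a torch.nn.Module, hence the class-name check added to optimize_model) and swaps the UNet's attention processors for ipex-llm's SD 1.5 implementation, guarding the import so a missing diffusers package never affects other models. A hedged usage sketch (checkpoint name and prompt are illustrative):

import torch
from diffusers import StableDiffusionPipeline
from ipex_llm import optimize_model

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative SD 1.5 checkpoint
    torch_dtype=torch.float16,
)

# low_bit=None: no quantization, but _optimize_post installs the
# ipex-llm AttnProcessor2_0 on pipe.unet and returns the pipeline.
pipe = optimize_model(pipe, low_bit=None)
image = pipe("an astronaut riding a horse").images[0]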