support passing None to low_bit in optimize_model (#12121)

This commit is contained in:
Yishuo Wang 2024-09-26 11:09:35 +08:00 committed by GitHub
parent 47e0b83cbf
commit 77af9bc5fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 32 deletions

View file

@ -202,7 +202,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
:param model: The original PyTorch model (nn.module)
:param low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, ``'sym_int5'``,
``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, ``'nf4'``, ``'fp4'``,
``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'`` or ``'bf16'``,
``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'``, ``'bf16'`` or None,
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
Relevant low bit optimizations will be applied to the model.
@ -225,10 +225,11 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
>>> # (Optional) you can also save the optimized model by calling 'save_low_bit'
>>> model.save_low_bit(saved_dir)
"""
invalidInputError(low_bit in ggml_tensor_qtype,
invalidInputError(low_bit is None or low_bit in ggml_tensor_qtype,
f"Unknown load_in_low_bit value: {low_bit}, expected:"
f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
invalidInputError(isinstance(model, torch.nn.Module),
invalidInputError(isinstance(model, torch.nn.Module) or
model.__class__.__name__ == "StableDiffusionPipeline",
"model should be an instance of "
f"`torch.nn.Module`, but got {type(model)} at last.")
# To adapt vLLM models
@ -249,7 +250,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
torch_dtype = torch.float16
else:
torch_dtype = kwargs.get("torch_dtype", "auto")
qtype = ggml_tensor_qtype[low_bit]
qtype = ggml_tensor_qtype[low_bit] if low_bit is not None else None
model = ggml_convert_low_bit(model,
qtype=qtype,
torch_dtype=torch_dtype,

View file

@ -1060,7 +1060,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[index]} "
f"format......")
else:
elif qtype in gguf_mixed_qtype.values():
index = list(gguf_mixed_qtype.values()).index(qtype)
logger.info(f"Converting the current model to "
f"{list(gguf_mixed_qtype.keys())[index]} "
@ -1089,34 +1089,35 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
enable_scale_search = use_scale_search(model_config, qtype)
# mixed quantization needs model_config to choose custom quantization strategy
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
model_config=model_config,
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision,
act_order=act_order,
enable_scale_search=enable_scale_search,
)
if not has_been_replaced:
warnings.warn(
"No linear modules were found in "
"your model. This can happen for some architectures such as gpt2 that uses Conv1D "
"instead of Linear layers. Please double check your model architecture, or submit "
"an issue on github if you think this is a bug."
if qtype is not None:
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
embedding_qtype=embedding_qtype,
model_config=model_config,
torch_dtype=torch_dtype,
enable_xetla=enable_xetla,
mixed_precision=mixed_precision,
act_order=act_order,
enable_scale_search=enable_scale_search,
)
elif device == "cpu":
if not (getattr(model, "quantization_method", None) == "gptq"):
if torch_dtype == "auto":
convert_bigdl_other_module(model, torch.float32)
else:
convert_bigdl_other_module(model, torch_dtype)
elif device == "meta":
# Do nothing here for weights are empty.
pass
if not has_been_replaced:
warnings.warn(
"No linear modules were found in "
"your model. This can happen for some architectures such as gpt2 that uses Conv1D "
"instead of Linear layers. Please double check your model architecture, or submit "
"an issue on github if you think this is a bug."
)
elif device == "cpu":
if not (getattr(model, "quantization_method", None) == "gptq"):
if torch_dtype == "auto":
convert_bigdl_other_module(model, torch.float32)
else:
convert_bigdl_other_module(model, torch_dtype)
elif device == "meta":
# Do nothing here for weights are empty.
pass
if optimize_model:
model = _optimize_post(model, lightweight_bmm)
@ -1221,6 +1222,15 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
def _optimize_post(model, lightweight_bmm=False):
try:
from diffusers import StableDiffusionPipeline
if isinstance(model, StableDiffusionPipeline):
from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
model.unet.set_attn_processor(AttnProcessor2_0())
return model
except ModuleNotFoundError:
pass
try:
from sentence_transformers.SentenceTransformer import SentenceTransformer
if isinstance(model, SentenceTransformer):