support passing None to low_bit in optimize_model (#12121)
parent 47e0b83cbf
commit 77af9bc5fa
2 changed files with 43 additions and 32 deletions
@@ -202,7 +202,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     :param model: The original PyTorch model (nn.module)
     :param low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, ``'sym_int5'``,
                     ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, ``'nf4'``, ``'fp4'``,
-                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'`` or ``'bf16'``,
+                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'``, ``'bf16'`` or None,
                     ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
                     asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
                     Relevant low bit optimizations will be applied to the model.
@@ -225,10 +225,11 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     >>> # (Optional) you can also save the optimized model by calling 'save_low_bit'
     >>> model.save_low_bit(saved_dir)
     """
-    invalidInputError(low_bit in ggml_tensor_qtype,
+    invalidInputError(low_bit is None or low_bit in ggml_tensor_qtype,
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-    invalidInputError(isinstance(model, torch.nn.Module),
+    invalidInputError(isinstance(model, torch.nn.Module) or
+                      model.__class__.__name__ == "StableDiffusionPipeline",
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
     # To adapt vLLM models
@@ -249,7 +250,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         torch_dtype = torch.float16
     else:
         torch_dtype = kwargs.get("torch_dtype", "auto")
-    qtype = ggml_tensor_qtype[low_bit]
+    qtype = ggml_tensor_qtype[low_bit] if low_bit is not None else None
    model = ggml_convert_low_bit(model,
                                 qtype=qtype,
                                 torch_dtype=torch_dtype,
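Taken together, the docstring, input-check, and qtype changes above let `low_bit=None` flow through `optimize_model` without any weight quantization. A minimal usage sketch, assuming the usual `from ipex_llm import optimize_model` entry point; the model id and dtype below are illustrative, not part of this commit:

    >>> import torch
    >>> from transformers import AutoModelForCausalLM
    >>> from ipex_llm import optimize_model
    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
    ...                                              torch_dtype=torch.float16)
    >>> # low_bit=None: no qtype lookup and no low-bit conversion; other optimizations still apply
    >>> model = optimize_model(model, low_bit=None)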
@@ -1060,7 +1060,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         logger.info(f"Converting the current model to "
                     f"{list(ggml_tensor_qtype.keys())[index]} "
                     f"format......")
-    else:
+    elif qtype in gguf_mixed_qtype.values():
         index = list(gguf_mixed_qtype.values()).index(qtype)
         logger.info(f"Converting the current model to "
                     f"{list(gguf_mixed_qtype.keys())[index]} "
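Turning the old catch-all `else:` into an explicit membership test keeps a `None` qtype out of the gguf-mixed logging branch, whose `.index(qtype)` lookup would otherwise raise. Roughly, with hypothetical dictionary contents (only the `.index()` behaviour matters here):

    >>> gguf_mixed_qtype = {"gguf_q4k_s": 57, "gguf_q4k_m": 58}  # hypothetical values
    >>> list(gguf_mixed_qtype.values()).index(None)              # what the old else: path would hit
    Traceback (most recent call last):
        ...
    ValueError: None is not in list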
@@ -1089,34 +1089,35 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     enable_scale_search = use_scale_search(model_config, qtype)

     # mixed quantization needs model_config to choose custom quantization strategy
-    model, has_been_replaced = _replace_with_low_bit_linear(
-        model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
-        imatrix_data=imatrix_data,
-        embedding_qtype=embedding_qtype,
-        model_config=model_config,
-        torch_dtype=torch_dtype,
-        enable_xetla=enable_xetla,
-        mixed_precision=mixed_precision,
-        act_order=act_order,
-        enable_scale_search=enable_scale_search,
-    )
-    if not has_been_replaced:
-        warnings.warn(
-            "No linear modules were found in "
-            "your model. This can happen for some architectures such as gpt2 that uses Conv1D "
-            "instead of Linear layers. Please double check your model architecture, or submit "
-            "an issue on github if you think this is a bug."
-        )
-    elif device == "cpu":
-        if not (getattr(model, "quantization_method", None) == "gptq"):
-            if torch_dtype == "auto":
-                convert_bigdl_other_module(model, torch.float32)
-            else:
-                convert_bigdl_other_module(model, torch_dtype)
-    elif device == "meta":
-        # Do nothing here for weights are empty.
-        pass
+    if qtype is not None:
+        model, has_been_replaced = _replace_with_low_bit_linear(
+            model, qtype, modules_to_not_convert,
+            convert_shape_only, cpu_embedding,
+            imatrix_data=imatrix_data,
+            embedding_qtype=embedding_qtype,
+            model_config=model_config,
+            torch_dtype=torch_dtype,
+            enable_xetla=enable_xetla,
+            mixed_precision=mixed_precision,
+            act_order=act_order,
+            enable_scale_search=enable_scale_search,
+        )
+        if not has_been_replaced:
+            warnings.warn(
+                "No linear modules were found in "
+                "your model. This can happen for some architectures such as gpt2 that uses Conv1D "
+                "instead of Linear layers. Please double check your model architecture, or submit "
+                "an issue on github if you think this is a bug."
+            )
+        elif device == "cpu":
+            if not (getattr(model, "quantization_method", None) == "gptq"):
+                if torch_dtype == "auto":
+                    convert_bigdl_other_module(model, torch.float32)
+                else:
+                    convert_bigdl_other_module(model, torch_dtype)
+        elif device == "meta":
+            # Do nothing here for weights are empty.
+            pass

     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
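With this guard in place, a `None` qtype skips the whole `_replace_with_low_bit_linear` pass (and its CPU/meta follow-up) and the model drops straight through to `_optimize_post`. A hedged way to observe the effect, continuing the illustrative `model` from the earlier sketch:

    >>> import torch.nn as nn
    >>> opt = optimize_model(model, low_bit=None)
    >>> # Expected (assuming the model uses plain nn.Linear layers): no low-bit replacements
    >>> sorted({type(m).__name__ for m in opt.modules() if isinstance(m, nn.Linear)})
    ['Linear']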
@@ -1221,6 +1222,15 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 
 
 def _optimize_post(model, lightweight_bmm=False):
+    try:
+        from diffusers import StableDiffusionPipeline
+        if isinstance(model, StableDiffusionPipeline):
+            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            model.unet.set_attn_processor(AttnProcessor2_0())
+            return model
+    except ModuleNotFoundError:
+        pass
+
     try:
         from sentence_transformers.SentenceTransformer import SentenceTransformer
         if isinstance(model, SentenceTransformer):
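Together with the relaxed isinstance check in `optimize_model`, this new `_optimize_post` branch lets a diffusers `StableDiffusionPipeline` be passed in directly; its UNet attention processors are swapped for the sd15 `AttnProcessor2_0`. A hedged usage sketch; the model id and the `xpu` placement are assumptions, not part of this commit:

    >>> import torch
    >>> from diffusers import StableDiffusionPipeline
    >>> from ipex_llm import optimize_model
    >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
    ...                                                torch_dtype=torch.float16)
    >>> pipe = optimize_model(pipe, low_bit=None)   # patches attention processors, no quantization
    >>> pipe = pipe.to("xpu")                       # assumes an Intel GPU and an IPEX-enabled PyTorch build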