add save & load support for NPU optimized model (#11999)

* add save &  load support

* fix style
This commit is contained in:
Ruonan Wang 2024-09-03 05:53:22 -07:00 committed by GitHub
parent 6eb55653ba
commit 9eaff5e47d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -174,6 +174,7 @@ class _BaseAutoModelClass:
intra_pp=intra_pp, intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache, transpose_value_cache=transpose_value_cache,
) )
model.save_low_bit = types.MethodType(save_low_bit, model)
else: else:
from ipex_llm.transformers.npu_models.convert import optimize_llm from ipex_llm.transformers.npu_models.convert import optimize_llm
optimize_llm(model) optimize_llm(model)
@ -209,10 +210,16 @@ class _BaseAutoModelClass:
ignore_argument(kwargs, "lightweight_bmm") ignore_argument(kwargs, "lightweight_bmm")
ignore_argument(kwargs, "cpu_embedding") ignore_argument(kwargs, "cpu_embedding")
ignore_argument(kwargs, "embedding_qtype") ignore_argument(kwargs, "embedding_qtype")
ignore_argument(kwargs, "optimize_model")
ignore_argument(kwargs, "modules_to_not_convert") ignore_argument(kwargs, "modules_to_not_convert")
ignore_argument(kwargs, "speculative") ignore_argument(kwargs, "speculative")
ignore_argument(kwargs, "pipeline_parallel_stages") ignore_argument(kwargs, "pipeline_parallel_stages")
optimize_model = kwargs.pop("optimize_model", False)
max_output_len = kwargs.pop("max_output_len", 1024)
max_prompt_len = kwargs.pop("max_prompt_len", 512)
inter_pp = kwargs.pop("inter_pp", None)
intra_pp = kwargs.pop("intra_pp", None)
transpose_value_cache = kwargs.pop("transpose_value_cache", True)
modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
from transformers.models.auto.configuration_auto import AutoConfig from transformers.models.auto.configuration_auto import AutoConfig
from transformers.modeling_utils import no_init_weights, get_state_dict_dtype from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
@ -351,12 +358,34 @@ class _BaseAutoModelClass:
logger.info(f"Converting model, it may takes up to several minutes ...") logger.info(f"Converting model, it may takes up to several minutes ...")
from intel_npu_acceleration_library.compiler import create_npu_kernels from intel_npu_acceleration_library.compiler import create_npu_kernels
with torch.no_grad(): if optimize_model:
optimize_llm(model) invalidInputError(
cls.load_convert(qtype, model, quant_device, *model_args, **kwargs) max_prompt_len < max_output_len,
create_npu_kernels(model) (
f"max_prompt_len ({max_prompt_len}) should be less"
" than max_output_len ({max_output_len})"
),
)
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre
model = model.eval() if hasattr(model, "llm"):
llm = model.llm
else:
llm = model
with torch.no_grad():
optimize_llm_pre(model, qtype)
cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
*model_args, **kwargs)
create_npu_kernels(llm)
else:
from ipex_llm.transformers.npu_models.convert import optimize_llm
optimize_llm(model)
with torch.no_grad():
cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
*model_args, **kwargs)
create_npu_kernels(model)
if is_sharded: if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
@ -415,6 +444,17 @@ class _BaseAutoModelClass:
for param in model.parameters(): for param in model.parameters():
param.requires_grad_(False) param.requires_grad_(False)
if optimize_model:
from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
optimize_llm(
llm,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
)
return model return model