add save & load support for NPU optimized model (#11999)
* add save & load support
* fix style
Parent: 6eb55653ba
Commit: 9eaff5e47d
1 changed file with 46 additions and 6 deletions
@@ -174,6 +174,7 @@ class _BaseAutoModelClass:
                    intra_pp=intra_pp,
                    transpose_value_cache=transpose_value_cache,
                )
                model.save_low_bit = types.MethodType(save_low_bit, model)
            else:
                from ipex_llm.transformers.npu_models.convert import optimize_llm
                optimize_llm(model)
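For context, a minimal usage sketch of the save path this hunk enables. The entry point (ipex_llm.transformers.npu_model.AutoModelForCausalLM), the load_in_low_bit value, the model id, and the output directory are assumptions for illustration, not part of this diff:

# Minimal sketch (not from this commit): convert a model for NPU and save it.
# Module path, model id, low-bit format and output directory are assumptions.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder checkpoint id
    load_in_low_bit="sym_int4",        # assumed low-bit format
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
    transpose_value_cache=True,
)

# save_low_bit is the method attached above via types.MethodType(save_low_bit, model)
model.save_low_bit("./llama2-7b-npu-low-bit")   # placeholder directory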
@@ -209,10 +210,16 @@ class _BaseAutoModelClass:
        ignore_argument(kwargs, "lightweight_bmm")
        ignore_argument(kwargs, "cpu_embedding")
        ignore_argument(kwargs, "embedding_qtype")
        ignore_argument(kwargs, "optimize_model")
        ignore_argument(kwargs, "modules_to_not_convert")
        ignore_argument(kwargs, "speculative")
        ignore_argument(kwargs, "pipeline_parallel_stages")
        optimize_model = kwargs.pop("optimize_model", False)
        max_output_len = kwargs.pop("max_output_len", 1024)
        max_prompt_len = kwargs.pop("max_prompt_len", 512)
        inter_pp = kwargs.pop("inter_pp", None)
        intra_pp = kwargs.pop("intra_pp", None)
        transpose_value_cache = kwargs.pop("transpose_value_cache", True)
        modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])

        from transformers.models.auto.configuration_auto import AutoConfig
        from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
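This hunk threads the same NPU tuning knobs through the low-bit load path (the ignore_argument / kwargs.pop pattern suggests the load_low_bit classmethod). A hedged sketch of reloading the checkpoint saved above follows; the method name and path are assumptions based on the commit title, while the keyword names and defaults mirror the kwargs popped here:

# Minimal sketch (not from this commit): reload the saved low-bit folder on NPU.
# Assumes AutoModelForCausalLM.load_low_bit exists on the NPU entry point, as the
# commit title ("save & load support") suggests; the path is a placeholder.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.load_low_bit(
    "./llama2-7b-npu-low-bit",
    optimize_model=True,
    max_output_len=1024,      # defaults mirror the pops in the hunk above
    max_prompt_len=512,
    inter_pp=None,            # None keeps the library's pipeline-split defaults
    intra_pp=None,
    transpose_value_cache=True,
)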
@@ -351,12 +358,34 @@ class _BaseAutoModelClass:
logger.info(f"Converting model, it may takes up to several minutes ...")
|
||||
        from intel_npu_acceleration_library.compiler import create_npu_kernels

        with torch.no_grad():
            optimize_llm(model)
            cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
            create_npu_kernels(model)
        if optimize_model:
            invalidInputError(
                max_prompt_len < max_output_len,
                (
f"max_prompt_len ({max_prompt_len}) should be less"
|
||||
" than max_output_len ({max_output_len})"
|
                ),
            )
            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre

            model = model.eval()
            if hasattr(model, "llm"):
                llm = model.llm
            else:
                llm = model

            with torch.no_grad():
                optimize_llm_pre(model, qtype)
                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
                                 *model_args, **kwargs)
                create_npu_kernels(llm)

        else:
            from ipex_llm.transformers.npu_models.convert import optimize_llm
            optimize_llm(model)
            with torch.no_grad():
                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
                                 *model_args, **kwargs)
                create_npu_kernels(model)

        if is_sharded:
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
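One behavioral detail from this hunk: with optimize_model enabled, loading now fails fast unless max_prompt_len is strictly smaller than max_output_len. A tiny illustration with arbitrary values:

# Arbitrary example values for the new constraint enforced via invalidInputError.
max_prompt_len, max_output_len = 512, 1024     # accepted: 512 < 1024
# max_prompt_len, max_output_len = 1024, 1024  # rejected: 1024 is not < 1024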
@@ -415,6 +444,17 @@ class _BaseAutoModelClass:
        for param in model.parameters():
            param.requires_grad_(False)

        if optimize_model:
            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
            optimize_llm(
                llm,
                max_output_len=max_output_len,
                max_prompt_len=max_prompt_len,
                inter_pp=inter_pp,
                intra_pp=intra_pp,
                transpose_value_cache=transpose_value_cache,
            )

        return model
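After the optimize_llm pass above, the reloaded model is ready for inference. A short sketch of generating with it, reusing the `model` object from the earlier load example; the tokenizer id and prompt are placeholders:

# Generation with the reloaded NPU model; tokenizer id and prompt are placeholders.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer("What is AI?", return_tensors="pt")

with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=32)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))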