add save & load support for NPU optimized model (#11999)
* add save & load support
* fix style

parent: 6eb55653ba
commit: 9eaff5e47d

1 changed file with 46 additions and 6 deletions
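Usage sketch (not part of the diff): with this change a model can be converted for the NPU once and the low-bit weights reused later. The snippet below assumes the ipex_llm NPU AutoModelForCausalLM entry point and an illustrative checkpoint path; argument names follow the kwargs handled in the hunks below, but defaults may differ between releases.

    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # Convert once on first load; optimize_model=True takes the NPU-optimized path.
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/original/checkpoint",   # illustrative path
        load_in_low_bit="sym_int4",
        optimize_model=True,
        max_output_len=1024,
        max_prompt_len=512,
    )

    # save_low_bit is attached to the model by the first hunk in this diff.
    model.save_low_bit("./npu-model-sym_int4")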
@@ -174,6 +174,7 @@ class _BaseAutoModelClass:
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
             )
+            model.save_low_bit = types.MethodType(save_low_bit, model)
         else:
             from ipex_llm.transformers.npu_models.convert import optimize_llm
             optimize_llm(model)
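Note on the added line above: types.MethodType(save_low_bit, model) binds the module-level save_low_bit function to this particular model instance, so model.save_low_bit(path) works like a regular method. A minimal, self-contained illustration of the pattern (the save_low_bit body here is a stand-in, not the ipex_llm implementation):

    import types

    def save_low_bit(self, save_dir):
        # Stand-in: the real function serializes the converted low-bit
        # weights and config into save_dir.
        print(f"saving {type(self).__name__} to {save_dir}")

    class DummyModel:
        pass

    model = DummyModel()
    model.save_low_bit = types.MethodType(save_low_bit, model)
    model.save_low_bit("./npu_model")  # 'self' is bound to this instance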
@@ -209,10 +210,16 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "lightweight_bmm")
         ignore_argument(kwargs, "cpu_embedding")
         ignore_argument(kwargs, "embedding_qtype")
-        ignore_argument(kwargs, "optimize_model")
         ignore_argument(kwargs, "modules_to_not_convert")
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")
+        optimize_model = kwargs.pop("optimize_model", False)
+        max_output_len = kwargs.pop("max_output_len", 1024)
+        max_prompt_len = kwargs.pop("max_prompt_len", 512)
+        inter_pp = kwargs.pop("inter_pp", None)
+        intra_pp = kwargs.pop("intra_pp", None)
+        transpose_value_cache = kwargs.pop("transpose_value_cache", True)
+        modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])

         from transformers.models.auto.configuration_auto import AutoConfig
         from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
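The kwargs.pop calls above define the options load_low_bit now accepts when restoring a saved NPU model. From the caller's side that corresponds to something like the sketch below (parameter names taken from the pops above; values are illustrative and defaults may change between releases):

    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model = AutoModelForCausalLM.load_low_bit(
        "./npu-model-sym_int4",   # directory written by model.save_low_bit(...)
        optimize_model=True,      # take the NPU-optimized path added below
        max_output_len=1024,      # total sequence budget
        max_prompt_len=512,       # must stay below max_output_len
        transpose_value_cache=True,
    )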
@@ -351,12 +358,34 @@ class _BaseAutoModelClass:
         logger.info(f"Converting model, it may takes up to several minutes ...")
         from intel_npu_acceleration_library.compiler import create_npu_kernels

-        with torch.no_grad():
-            optimize_llm(model)
-            cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
-            create_npu_kernels(model)
-        model = model.eval()
+        if optimize_model:
+            invalidInputError(
+                max_prompt_len < max_output_len,
+                (
+                    f"max_prompt_len ({max_prompt_len}) should be less"
+                    " than max_output_len ({max_output_len})"
+                ),
+            )
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre
+
+            if hasattr(model, "llm"):
+                llm = model.llm
+            else:
+                llm = model
+
+            with torch.no_grad():
+                optimize_llm_pre(model, qtype)
+                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
+                                 *model_args, **kwargs)
+                create_npu_kernels(llm)
+
+        else:
+            from ipex_llm.transformers.npu_models.convert import optimize_llm
+            optimize_llm(model)
+            with torch.no_grad():
+                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
+                                 *model_args, **kwargs)
+                create_npu_kernels(model)

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
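The new invalidInputError call guards the relationship between the two length budgets: the prompt has to fit inside the overall sequence length the NPU kernels are set up for (that reading is inferred from the parameter names; check the ipex_llm docs for the exact semantics). A plain-Python equivalent of the check:

    max_prompt_len = 512
    max_output_len = 1024

    # Equivalent of the invalidInputError guard added above.
    if not max_prompt_len < max_output_len:
        raise ValueError(
            f"max_prompt_len ({max_prompt_len}) should be less "
            f"than max_output_len ({max_output_len})"
        )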
@@ -415,6 +444,17 @@ class _BaseAutoModelClass:
         for param in model.parameters():
             param.requires_grad_(False)

+        if optimize_model:
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+            optimize_llm(
+                llm,
+                max_output_len=max_output_len,
+                max_prompt_len=max_prompt_len,
+                inter_pp=inter_pp,
+                intra_pp=intra_pp,
+                transpose_value_cache=transpose_value_cache,
+            )
+
         return model
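After load_low_bit returns, the restored object behaves like a regular transformers model, so generation uses the standard API. An illustrative continuation of the load sketch above (tokenizer path and prompt are placeholders):

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/original/checkpoint",
                                              trust_remote_code=True)
    with torch.inference_mode():
        input_ids = tokenizer("What is an NPU?", return_tensors="pt").input_ids
        output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))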