refactor from_pretrained API for NPU (#11927)
commit 6c3eb1e1e8
parent 7ca557aada

4 changed files with 19 additions and 9 deletions

@@ -68,7 +68,7 @@ if __name__ == "__main__":
         trust_remote_code=True,
         attn_implementation="eager",
         load_in_low_bit="sym_int4",
-        enable_mp=True,
+        optimize_model=True,
         max_output_len=args.max_output_len,
         max_prompt_len=args.max_prompt_len,
         intra_pp=args.intra_pp,

@@ -55,7 +55,7 @@ if __name__ == "__main__":
         trust_remote_code=True,
         attn_implementation="eager",
         load_in_low_bit="sym_int4",
-        enable_mp=True,
+        optimize_model=True,
         max_output_len=args.max_output_len,
         max_prompt_len=args.max_prompt_len,
         intra_pp=args.intra_pp,

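For context, a minimal sketch of how an example script calls the refactored API after this change. The import path, model id, dtype, and generation loop are assumptions based on the example scripts touched here; the keyword names (`optimize_model`, `load_in_low_bit`, `max_output_len`, `max_prompt_len`, `transpose_value_cache`) come from this diff.

```python
# Hedged sketch of the refactored NPU from_pretrained call.
# Import path and model id are assumptions, not taken from this commit.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import path

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="eager",
    load_in_low_bit="sym_int4",
    optimize_model=True,        # replaces the old enable_mp=True flag
    max_output_len=1024,
    max_prompt_len=512,         # must stay below max_output_len (checked below)
    transpose_value_cache=True,
    # inter_pp / intra_pp may now be omitted: they default to None and are
    # resolved per model type inside optimize_llm (see the later hunks)
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
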
@@ -110,16 +110,16 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "mixed_precision")
         ignore_argument(kwargs, "cpu_embedding")
         ignore_argument(kwargs, "embedding_qtype")
-        ignore_argument(kwargs, "optimize_model")
+        ignore_argument(kwargs, "enable_mp")
         ignore_argument(kwargs, "modules_to_not_convert")
         ignore_argument(kwargs, "quantization_config")
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")
-        enable_mp = kwargs.pop("enable_mp", False)
+        optimize_model = kwargs.pop("optimize_model", False)
         max_output_len = kwargs.pop("max_output_len", 1024)
         max_prompt_len = kwargs.pop("max_prompt_len", max_output_len)
-        inter_pp = kwargs.pop("inter_pp", 2)
-        intra_pp = kwargs.pop("intra_pp", 2)
+        inter_pp = kwargs.pop("inter_pp", None)
+        intra_pp = kwargs.pop("intra_pp", None)
         transpose_value_cache = kwargs.pop("transpose_value_cache", True)

         _args = copy.deepcopy(args)

@@ -140,7 +140,7 @@ class _BaseAutoModelClass:
         logger.info(f"Converting model, it may takes up to several minutes ...")
         from intel_npu_acceleration_library.compiler import create_npu_kernels

-        if enable_mp:
+        if optimize_model:
             invalidInputError(
                 max_prompt_len < max_output_len,
                 (

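The two hunks above rename the user-facing flag from `enable_mp` to `optimize_model` and stop hard-coding the pipeline split sizes in `from_pretrained`. A standalone sketch of that pattern, with hypothetical names (this is not the library's API), looks like:

```python
# Illustration only: pop the optional kwargs with None sentinels so a
# downstream optimizer can substitute per-model defaults. All names here
# are hypothetical.
def load_for_npu(model_loader, *args, **kwargs):
    optimize_model = kwargs.pop("optimize_model", False)
    max_output_len = kwargs.pop("max_output_len", 1024)
    max_prompt_len = kwargs.pop("max_prompt_len", max_output_len)
    inter_pp = kwargs.pop("inter_pp", None)   # None -> decided per model type later
    intra_pp = kwargs.pop("intra_pp", None)

    model = model_loader(*args, **kwargs)

    if optimize_model:
        assert max_prompt_len < max_output_len, (
            "max_prompt_len must be smaller than max_output_len"
        )
        # an optimize_llm-style hook would receive the None values and
        # replace them with model-specific defaults
        ...
    return model
```
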
@@ -29,11 +29,16 @@ def optimize_llm(
     model: torch.nn.Module,
     max_output_len=1024,
     max_prompt_len=1024,
-    inter_pp=2,
-    intra_pp=2,
+    inter_pp=None,
+    intra_pp=None,
     transpose_value_cache=True,
 ):
     if model.config.model_type == "llama":
+        if intra_pp is None:
+            intra_pp = 2
+        if inter_pp is None:
+            inter_pp = 2
+
         from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
         from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
         from transformers.models.llama.modeling_llama import LlamaModel

@@ -60,6 +65,11 @@ def optimize_llm(
         convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
     elif model.config.model_type == "qwen2" and model.config.intermediate_size == 8960:
         # for qwen2-1.5B
+        if intra_pp is None:
+            intra_pp = 2
+        if inter_pp is None:
+            inter_pp = 1
+
         from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
         from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
         from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

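A condensed sketch of the default resolution these last two hunks introduce. Only the branch conditions and the values (2/2 for llama, 2/1 for the qwen2-1.5B case) come from the diff; the helper name and the rest of `optimize_llm` are omitted or invented for illustration.

```python
# Condensed sketch of the per-model pipeline-split defaults added here.
# resolve_pp_defaults is a hypothetical helper, not part of the library.
def resolve_pp_defaults(model_config, inter_pp=None, intra_pp=None):
    if model_config.model_type == "llama":
        intra_pp = 2 if intra_pp is None else intra_pp
        inter_pp = 2 if inter_pp is None else inter_pp
    elif model_config.model_type == "qwen2" and model_config.intermediate_size == 8960:
        # qwen2-1.5B uses a smaller inter-stage split by default
        intra_pp = 2 if intra_pp is None else intra_pp
        inter_pp = 1 if inter_pp is None else inter_pp
    return inter_pp, intra_pp
```
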