diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama.py
new file mode 100644
index 00000000..25607b88
--- /dev/null
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama.py
@@ -0,0 +1,90 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+import time
+import argparse
+from ipex_llm.transformers.npu_pipeline_model import AutoModelForCausalLM
+from transformers import AutoTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+def get_prompt(message: str, chat_history: list[tuple[str, str]],
+ system_prompt: str) -> str:
+    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+ # The first user input is _not_ stripped
+ do_strip = False
+ for user_input, response in chat_history:
+ user_input = user_input.strip() if do_strip else user_input
+ do_strip = True
+        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+ message = message.strip() if do_strip else message
+ texts.append(f'{message} [/INST]')
+ return ''.join(texts)
+
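+# For illustration only: with an empty chat history,
+# get_prompt("What is AI?", [], "You are helpful.") returns
+# '[INST] <<SYS>>\nYou are helpful.\n<</SYS>>\n\nWhat is AI? [/INST]'.
+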
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Predict Tokens using `generate()` API for npu model"
+ )
+ parser.add_argument(
+ "--repo-id-or-model-path",
+ type=str,
+ default=r"C:\\Llama2-converted-weights\\",
+ help="The folder path of converted model blobs",
+ )
+ parser.add_argument('--prompt', type=str, default="What is AI?",
+ help='Prompt to infer')
+ parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
+ parser.add_argument("--max-output-len", type=int, default=1024)
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+
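+    # Load pre-converted blob files for the NPU pipeline from the given directory
+    # (ov_model=True); converting an original HF checkpoint here is not supported yet.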
+ model = AutoModelForCausalLM.from_pretrained(model_path,
+ ov_model=True,
+ max_output_len=args.max_output_len,
+ model_name="Model70")
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ DEFAULT_SYSTEM_PROMPT = """\
+ """
+
+ print("-" * 80)
+ print("done")
+ with torch.inference_mode():
+ print("finish to load")
+ prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
+ _input_ids = tokenizer.encode(prompt, return_tensors="pt")
+ print("input length:", len(_input_ids[0]))
+ st = time.time()
+ output = model.generate(
+ _input_ids, max_new_tokens=args.n_predict,
+ )
+ end = time.time()
+ print(f"Inference time: {end-st} s")
+ input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
+ print("-" * 20, "Input", "-" * 20)
+ print(input_str)
+ output_str = tokenizer.decode(output[0], skip_special_tokens=False)
+ print("-" * 20, "Output", "-" * 20)
+ print(output_str)
+
+ print("-" * 80)
+ print("done")
+ print("success shut down")
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/__init__.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/__init__.py
new file mode 100644
index 00000000..50632808
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/__init__.py
@@ -0,0 +1,17 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .pipeline_model import *
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py
new file mode 100644
index 00000000..4bd8f579
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py
@@ -0,0 +1,64 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import ctypes
+import pathlib
+from ipex_llm.utils.common import invalidInputError
+
+
+def get_shared_lib_info(lib_base_name: str):
+ # Determine the file extension based on the platform
+ if sys.platform.startswith("linux") or sys.platform == "darwin":
+ lib_ext = ".so"
+ elif sys.platform == "win32":
+ lib_ext = ".dll"
+ else:
+ invalidInputError(False, "Unsupported platform.")
+
+    # Construct the paths to the possible shared library names (python/llm/src/ipex_llm/libs)
+ _base_path = pathlib.Path(__file__).parent.parent.parent.resolve()
+ _base_path = _base_path / 'libs'
+
+ lib_path = os.path.join(_base_path, lib_base_name + lib_ext)
+
+ return _base_path, lib_path
+
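+# Resolves to e.g. <ipex_llm>/libs/pipeline.dll on Windows or
+# <ipex_llm>/libs/pipeline.so on Linux.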
+_, _lib_path = get_shared_lib_info("pipeline")
+
+# Load the library
+_lib = ctypes.cdll.LoadLibrary(_lib_path)
+
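+# Declare the C prototypes exposed by the pipeline library: InitLLMPipeline takes
+# five ints (kv_len, num_head, head_dim, num_layers, vocab_size) followed by five
+# C strings; generate_serve takes five ints; both return int.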
+_lib.InitLLMPipeline.argtypes = [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5
+_lib.InitLLMPipeline.restype = ctypes.c_int
+
+_lib.generate_serve.argtypes = [ctypes.c_int] * 5
+_lib.generate_serve.restype = ctypes.c_int
+
+
+def InitLLMPipeline(kv_len: int, num_head: int, head_dim: int, num_layers: int, vocab_size: int,
+ model_weight_dir: str, model_name: str,
+ first_blob_name: str, last_blob_name: str, rest_blob_name: str):
+ return _lib.InitLLMPipeline(kv_len, num_head, head_dim, num_layers, vocab_size,
+ model_weight_dir.encode('utf-8'), model_name.encode('utf-8'),
+ first_blob_name.encode('utf-8'), last_blob_name.encode('utf-8'),
+ rest_blob_name.encode('utf-8'))
+
+
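+# Blocks while serving one generation round over the named pipes; the Python side
+# runs it on a background thread (see pipeline_model.generate).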
+def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int,
+ param_n_output: int):
+ _lib.generate_serve(kv_len, num_head, head_dim, num_layers, param_n_output)
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py
new file mode 100644
index 00000000..25bfd30a
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py
@@ -0,0 +1,246 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import numpy
+import warnings
+import torch
+import sys
+import transformers
+from unittest.mock import patch
+from transformers.dynamic_module_utils import get_imports
+from .pipeline_cpp import InitLLMPipeline, generate_serve
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from transformers import GenerationConfig, \
+ LogitsProcessorList, StoppingCriteriaList
+import threading
+from ipex_llm.utils.common import invalidInputError
+import os
+from transformers import PretrainedConfig
+
+
+def patch_flash_attn_import(filename: str) -> List[str]:
+ """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+ imports = get_imports(filename)
+ if "flash_attn" in imports:
+ imports.remove("flash_attn")
+ return imports
+
+
+def ignore_argument(kwargs: dict, key: str):
+ arg = kwargs.pop(key, None)
+ if arg is not None:
+ warnings.warn(f"argument `{key}={arg}` will be ignored")
+
+
+def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+):
+ new_generate_kwargs = {}
+ for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+ value = kwargs.pop(var, None)
+ if value is not None:
+ new_generate_kwargs[var] = value
+
+ if isinstance(inputs[0], torch.Tensor):
+ numpy_input = inputs[0].numpy()
+ else:
+ numpy_input = inputs[0]
+ input_length = numpy.size(numpy_input)
+
+ new_tokens = new_generate_kwargs['max_new_tokens']
+ invalidInputError(input_length + new_tokens <= self.kv_len + 1,
+ "Input plus output tokens should not exceed max_output_len.")
+
+ # start generate_serve by Thread
+ thread = threading.Thread(target=generate_serve,
+ args=(self.kv_len, self.num_head,
+ self.head_dim, self.num_layers,
+ new_tokens))
+ thread.start()
+
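+    # Talk to generate_serve over two named pipes: the prompt token ids (plus an
+    # eos sentinel) are written to the input pipe as 4-byte native-endian integers,
+    # and generated token ids are read back from the output pipe 4 bytes at a time.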
+ in_pipe_path = "\\\\.\\pipe\\llminputpipe"
+ out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"
+
+ while True:
+ try:
+ input_pipe = open(in_pipe_path, "wb")
+        except Exception:
+ print('Waiting for input pipe')
+ time.sleep(1)
+ else:
+ break
+
+ while True:
+ try:
+ output_pipe = open(out_pipe_path, "rb")
+        except Exception:
+ print('Waiting for output pipe')
+ time.sleep(1)
+ else:
+ break
+
+ bdata = b''
+ for i in range(0, input_length):
+ d = int(numpy_input[i])
+ bdata = bdata + d.to_bytes(4, sys.byteorder)
+
+ if "eos_token_id" not in new_generate_kwargs:
+ eos = 0xffffffff
+ else:
+ eos = new_generate_kwargs["eos_token_id"]
+
+ bdata = bdata + eos.to_bytes(4, sys.byteorder)
+
+ input_pipe.write(bytearray(bdata))
+ input_pipe.flush()
+
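+    # Read generated tokens back one 4-byte integer at a time, stopping on eos or
+    # when the serving side closes the pipe.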
+ buffersize = 4
+ output_tokens = []
+ while True:
+ data = output_pipe.read(buffersize)
+ if len(data) == 0:
+ break
+ token = int.from_bytes(data, sys.byteorder)
+ output_tokens.append(torch.tensor([token]))
+ if streamer is not None:
+ streamer.put(torch.tensor([token]))
+ if token == eos:
+ break
+
+ output = torch.stack(output_tokens, dim=1)
+ if streamer is not None:
+ streamer.end()
+
+ thread.join()
+ return output
+
+
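+# Lightweight container for the pipeline configuration parsed from config.json;
+# from_pretrained attaches a bound generate() method to each instance.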
+class NPUModel():
+ def __init__(self):
+ pass
+
+
+class _BaseAutoModelClass:
+    HF_Model = None
+
+ @classmethod
+ @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+ """
+ Load a model from a directory or the HF Hub.
+ The loaded model will run supported OPs on NPU, then run other OPs on CPU.
+
+ Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
+        :param ov_model: boolean value, whether to load converted blob files from the
+            specified directory. If False, the original HF model would be converted
+            to the blob format first, which is not supported yet. Default to True.
+        :param max_output_len: Maximum context length for the whole generation,
+            default to 1024.
+ :param model_name: Name prefix of the model weight bin file.
+ :return: a model instance
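+
+        A minimal usage sketch (the directory name below is a placeholder for a
+        folder of converted blobs, and model_name should match the prefix used
+        during conversion):
+
+            model = AutoModelForCausalLM.from_pretrained(
+                "path/to/converted_blob_dir",
+                ov_model=True, max_output_len=1024, model_name="Model70")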
+ """
+ ov_model = kwargs.get("ov_model", True)
+ max_output_len = kwargs.pop("max_output_len", 1024)
+
+        invalidInputError(ov_model,
+                          "Loading the original HF model is not supported yet, "
+                          "please pass a directory of converted blobs with ov_model=True.")
+ invalidInputError(os.path.exists(pretrained_model_name_or_path),
+ "This directory does not exist, please double check it.")
+
+ config_json = os.path.join(pretrained_model_name_or_path, "config.json")
+ invalidInputError(os.path.exists(config_json),
+ "config.json is not found in current directory, please double check it.")
+ config = PretrainedConfig.from_json_file(config_json)
+ model = NPUModel()
+ model.kv_len = max_output_len - 1
+ model.num_head = config.num_attention_heads
+ model.head_dim = config.hidden_size // config.num_attention_heads
+ model.num_layers = config.num_hidden_layers
+ model.vocab_size = config.vocab_size
+
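+        # The converted model directory is expected to contain per-layer weights
+        # under model_layer_weights/ plus first_model.blob, rest_model.blob and
+        # last_model.blob files.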
+ model_weight_dir = os.path.join(pretrained_model_name_or_path, "model_layer_weights")
+ model_name = kwargs.get("model_name", "Model")
+ first_blob_name = os.path.join(pretrained_model_name_or_path, "first_model.blob")
+ last_blob_name = os.path.join(pretrained_model_name_or_path, "last_model.blob")
+ rest_blob_name = os.path.join(pretrained_model_name_or_path, "rest_model.blob")
+
+ for path in [model_weight_dir, first_blob_name, last_blob_name, rest_blob_name]:
+ invalidInputError(os.path.exists(path),
+ f"{path} is not found in current directory, please double check it.")
+
+ try:
+ res = InitLLMPipeline(model.kv_len, model.num_head, model.head_dim, model.num_layers,
+ model.vocab_size, model_weight_dir, model_name,
+ first_blob_name, last_blob_name, rest_blob_name)
+        except Exception:
+            invalidInputError(False,
+                              "Failed to initialize the LLM pipeline.")
+
+ # patch generate function
+ import types
+ model.generate = types.MethodType(generate, model)
+ return model
+
+
+class AutoModelForCausalLM(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForCausalLM
+
+
+class AutoModel(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModel
+
+
+class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForSpeechSeq2Seq
+
+
+class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForSeq2SeqLM
+
+
+class AutoModelForSequenceClassification(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForSequenceClassification
+
+
+class AutoModelForMaskedLM(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForMaskedLM
+
+
+class AutoModelForQuestionAnswering(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForQuestionAnswering
+
+
+class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForNextSentencePrediction
+
+
+class AutoModelForMultipleChoice(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForMultipleChoice
+
+
+class AutoModelForTokenClassification(_BaseAutoModelClass):
+ HF_Model = transformers.AutoModelForTokenClassification