diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama.py
new file mode 100644
index 00000000..25607b88
--- /dev/null
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/llama.py
@@ -0,0 +1,90 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+import time
+import argparse
+from ipex_llm.transformers.npu_pipeline_model import AutoModelForCausalLM
+from transformers import AutoTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+def get_prompt(message: str, chat_history: list[tuple[str, str]],
+               system_prompt: str) -> str:
+    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    # The first user input is _not_ stripped
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+    message = message.strip() if do_strip else message
+    texts.append(f'{message} [/INST]')
+    return ''.join(texts)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Predict Tokens using `generate()` API for npu model"
+    )
+    parser.add_argument(
+        "--repo-id-or-model-path",
+        type=str,
+        default=r"C:\\Llama2-converted-weights\\",
+        help="The folder path of converted model blobs",
+    )
+    parser.add_argument('--prompt', type=str, default="What is AI?",
+                        help='Prompt to infer')
+    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
+    parser.add_argument("--max-output-len", type=int, default=1024)
+
+    args = parser.parse_args()
+    model_path = args.repo_id_or_model_path
+
+    model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                 ov_model=True,
+                                                 max_output_len=args.max_output_len,
+                                                 model_name="Model70")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    DEFAULT_SYSTEM_PROMPT = """\
+"""
+
+    print("-" * 80)
+    print("done")
+    with torch.inference_mode():
+        print("finish to load")
+        prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
+        _input_ids = tokenizer.encode(prompt, return_tensors="pt")
+        print("input length:", len(_input_ids[0]))
+        st = time.time()
+        output = model.generate(
+            _input_ids, max_new_tokens=args.n_predict,
+        )
+        end = time.time()
+        print(f"Inference time: {end-st} s")
+        input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
+        print("-" * 20, "Input", "-" * 20)
+        print(input_str)
+        output_str = tokenizer.decode(output[0], skip_special_tokens=False)
+        print("-" * 20, "Output", "-" * 20)
+        print(output_str)
+
+    print("-" * 80)
+    print("done")
+    print("success shut down")
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/__init__.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/__init__.py
new file mode 100644
index 00000000..50632808
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/__init__.py
@@ -0,0 +1,17 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .pipeline_model import *
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py
new file mode 100644
index 00000000..4bd8f579
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_cpp.py
@@ -0,0 +1,64 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import ctypes
+import pathlib
+from ipex_llm.utils.common import invalidInputError
+
+
+def get_shared_lib_info(lib_base_name: str):
+    # Determine the file extension based on the platform
+    if sys.platform.startswith("linux") or sys.platform == "darwin":
+        lib_ext = ".so"
+    elif sys.platform == "win32":
+        lib_ext = ".dll"
+    else:
+        invalidInputError(False, "Unsupported platform.")
+
+    # Construct the paths to the possible shared library names (python/llm/src/ipex-llm/llm/libs)
+    _base_path = pathlib.Path(__file__).parent.parent.parent.resolve()
+    _base_path = _base_path / 'libs'
+
+    lib_path = os.path.join(_base_path, lib_base_name + lib_ext)
+
+    return _base_path, lib_path
+
+_, _lib_path = get_shared_lib_info("pipeline")
+
+# Load the library
+_lib = ctypes.cdll.LoadLibrary(_lib_path)
+
+_lib.InitLLMPipeline.argtypes = [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5
+_lib.InitLLMPipeline.restype = ctypes.c_int
+
+_lib.generate_serve.argtypes = [ctypes.c_int] * 5
+_lib.generate_serve.restype = ctypes.c_int
+
+
+def InitLLMPipeline(kv_len: int, num_head: int, head_dim: int, num_layers: int, vocab_size: int,
+                    model_weight_dir: str, model_name: str,
+                    first_blob_name: str, last_blob_name: str, rest_blob_name: str):
+    return _lib.InitLLMPipeline(kv_len, num_head, head_dim, num_layers, vocab_size,
+                                model_weight_dir.encode('utf-8'), model_name.encode('utf-8'),
+                                first_blob_name.encode('utf-8'), last_blob_name.encode('utf-8'),
+                                rest_blob_name.encode('utf-8'))
+
+
+def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int,
+                   param_n_output: int):
+    _lib.generate_serve(kv_len, num_head, head_dim, num_layers, param_n_output)
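Note: the ctypes bindings above are normally driven through the AutoModelForCausalLM wrapper added in pipeline_model.py below, but they can also be exercised directly. A minimal sketch follows; every path, the "Model70" prefix and the shape values are illustrative placeholders, not values shipped with this PR:

    import threading
    from ipex_llm.transformers.npu_pipeline_model.pipeline_cpp import InitLLMPipeline, generate_serve

    # Shape/config values would normally come from the model's config.json
    kv_len, num_head, head_dim, num_layers, vocab_size = 1023, 32, 128, 32, 32000

    # Register the converted blobs with the native pipeline (placeholder paths)
    InitLLMPipeline(kv_len, num_head, head_dim, num_layers, vocab_size,
                    model_weight_dir="C:\\converted\\model_layer_weights",
                    model_name="Model70",
                    first_blob_name="C:\\converted\\first_model.blob",
                    last_blob_name="C:\\converted\\last_model.blob",
                    rest_blob_name="C:\\converted\\rest_model.blob")

    # generate_serve blocks while it streams tokens over named pipes,
    # which is why pipeline_model.py runs it on a background thread
    thread = threading.Thread(target=generate_serve,
                              args=(kv_len, num_head, head_dim, num_layers, 32))
    thread.start()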
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py
new file mode 100644
index 00000000..25bfd30a
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/pipeline_model.py
@@ -0,0 +1,246 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import numpy
+import warnings
+import torch
+import sys
+import transformers
+from typing import List
+from unittest.mock import patch
+from transformers.dynamic_module_utils import get_imports
+from .pipeline_cpp import InitLLMPipeline, generate_serve
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from transformers import GenerationConfig, \
+    LogitsProcessorList, StoppingCriteriaList
+import threading
+from ipex_llm.utils.common import invalidInputError
+import os
+from transformers import PretrainedConfig
+
+
+def patch_flash_attn_import(filename: str) -> List[str]:
+    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
+    imports = get_imports(filename)
+    if "flash_attn" in imports:
+        imports.remove("flash_attn")
+    return imports
+
+
+def ignore_argument(kwargs: dict, key: "str"):
+    arg = kwargs.pop(key, None)
+    if arg is not None:
+        warnings.warn(f"argument `{key}={arg}` will be ignored")
+
+
+def generate(
+    self,
+    inputs: Optional[torch.Tensor] = None,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
+    synced_gpus: Optional[bool] = None,
+    assistant_model: Optional["PreTrainedModel"] = None,
+    streamer: Optional["BaseStreamer"] = None,
+    **kwargs,
+):
+    new_generate_kwargs = {}
+    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+        value = kwargs.pop(var, None)
+        if value is not None:
+            new_generate_kwargs[var] = value
+
+    if isinstance(inputs[0], torch.Tensor):
+        numpy_input = inputs[0].numpy()
+    else:
+        numpy_input = inputs[0]
+    input_length = numpy.size(numpy_input)
+
+    new_tokens = new_generate_kwargs['max_new_tokens']
+    invalidInputError(input_length + new_tokens <= self.kv_len + 1,
+                      "Input plus output tokens should not exceed max_output_len.")
+
+    # start generate_serve by Thread
+    thread = threading.Thread(target=generate_serve,
+                              args=(self.kv_len, self.num_head,
+                                    self.head_dim, self.num_layers,
+                                    new_tokens))
+    thread.start()
+
+    in_pipe_path = "\\\\.\\pipe\\llminputpipe"
+    out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"
+
+    while True:
+        try:
+            input_pipe = open(in_pipe_path, "wb")
+        except:
+            print('Waiting for input pipe')
+            time.sleep(1)
+        else:
+            break
+
+    while True:
+        try:
+            output_pipe = open(out_pipe_path, "rb")
+        except:
+            print('Waiting for output pipe')
+            time.sleep(1)
+        else:
+            break
+
+    bdata = b''
+    for i in range(0, input_length):
+        d = int(numpy_input[i])
+        bdata = bdata + d.to_bytes(4, sys.byteorder)
+
+    if "eos_token_id" not in new_generate_kwargs:
+        eos = 0xffffffff
+    else:
+        eos = new_generate_kwargs["eos_token_id"]
+
+    bdata = bdata + eos.to_bytes(4, sys.byteorder)
+
+    input_pipe.write(bytearray(bdata))
+    input_pipe.flush()
+
+    buffersize = 4
+    output_tokens = []
+    while True:
+        data = output_pipe.read(buffersize)
+        if len(data) == 0:
+            break
+        token = int.from_bytes(data, sys.byteorder)
+        output_tokens.append(torch.tensor([token]))
+        if streamer is not None:
+            streamer.put(torch.tensor([token]))
+        if token == eos:
+            break
+
+    output = torch.stack(output_tokens, dim=1)
+    if streamer is not None:
+        streamer.end()
+
+    thread.join()
+    return output
+
+
+class NPUModel():
+    def __init__(self):
+        pass
+
+
+class _BaseAutoModelClass:
+    HF_Model = None
+
+    @classmethod
+    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        """
+        Load a model from a directory or the HF Hub.
+        The loaded model will run supported OPs on NPU, then run other OPs on CPU.
+
+        Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
+        :param ov_model: boolean value, whether to load blob files from the specified directory.
+                         If False, the HF model would first be converted to the blob format,
+                         which is not supported yet. Default to True.
+        :param max_output_len: Maximum context length for the whole generation, default to 1024.
+        :param model_name: Name prefix of the model weight bin file.
+        :return: a model instance
+        """
+        ov_model = kwargs.get("ov_model", True)
+        max_output_len = kwargs.pop("max_output_len", 1024)
+
+        invalidInputError(ov_model,
+                          "Original HF model is not supported now.")
+        invalidInputError(os.path.exists(pretrained_model_name_or_path),
+                          "This directory does not exist, please double check it.")
+
+        config_json = os.path.join(pretrained_model_name_or_path, "config.json")
+        invalidInputError(os.path.exists(config_json),
+                          "config.json is not found in current directory, please double check it.")
+        config = PretrainedConfig.from_json_file(config_json)
+        model = NPUModel()
+        model.kv_len = max_output_len - 1
+        model.num_head = config.num_attention_heads
+        model.head_dim = config.hidden_size // config.num_attention_heads
+        model.num_layers = config.num_hidden_layers
+        model.vocab_size = config.vocab_size
+
+        model_weight_dir = os.path.join(pretrained_model_name_or_path, "model_layer_weights")
+        model_name = kwargs.get("model_name", "Model")
+        first_blob_name = os.path.join(pretrained_model_name_or_path, "first_model.blob")
+        last_blob_name = os.path.join(pretrained_model_name_or_path, "last_model.blob")
+        rest_blob_name = os.path.join(pretrained_model_name_or_path, "rest_model.blob")
+
+        for path in [model_weight_dir, first_blob_name, last_blob_name, rest_blob_name]:
+            invalidInputError(os.path.exists(path),
+                              f"{path} is not found in current directory, please double check it.")
+
+        try:
+            res = InitLLMPipeline(model.kv_len, model.num_head, model.head_dim, model.num_layers,
+                                  model.vocab_size, model_weight_dir, model_name,
+                                  first_blob_name, last_blob_name, rest_blob_name)
+        except:
+            invalidInputError(False,
+                              "Failed to InitLLMPipeline.")
+            exit(0)
+
+        # patch generate function
+        import types
+        model.generate = types.MethodType(generate, model)
+        return model
+
+
+class AutoModelForCausalLM(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForCausalLM
+
+
+class AutoModel(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModel
+
+
+class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForSpeechSeq2Seq
+
+
+class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForSeq2SeqLM
+
+
+class AutoModelForSequenceClassification(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForSequenceClassification
+
+
+class AutoModelForMaskedLM(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForMaskedLM
+
+
+class AutoModelForQuestionAnswering(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForQuestionAnswering
+
+
+class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForNextSentencePrediction
+
+
+class AutoModelForMultipleChoice(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForMultipleChoice
+
+
+class AutoModelForTokenClassification(_BaseAutoModelClass):
+    HF_Model = transformers.AutoModelForTokenClassification
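Taken together, the new entry point is used as in the llama.py example above. A condensed sketch, assuming model_path points at a directory of converted blobs (the path and the "Model70" prefix are placeholders):

    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers.npu_pipeline_model import AutoModelForCausalLM

    model_path = "C:\\Llama2-converted-weights\\"   # placeholder: folder of converted blob files
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 ov_model=True,         # load pre-converted blobs
                                                 max_output_len=1024,   # bound on prompt + generated tokens
                                                 model_name="Model70")  # weight bin file name prefix
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with torch.inference_mode():
        input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
        output = model.generate(input_ids, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=False))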