Initial support of NPU level0 Model (#12177)

* first commit to support load dll and init llm pipeline

* add init generate

* fix style

* small updates

* fix style and check tokens number
This commit is contained in:
Ruonan Wang 2024-10-11 09:45:53 +08:00 committed by GitHub
parent ac44e98b7d
commit 4d93bb81fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 417 additions and 0 deletions

View file

@ -0,0 +1,90 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import time
import argparse
from ipex_llm.transformers.npu_pipeline_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
def get_prompt(message: str, chat_history: list[tuple[str, str]],
system_prompt: str) -> str:
texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
# The first user input is _not_ stripped
do_strip = False
for user_input, response in chat_history:
user_input = user_input.strip() if do_strip else user_input
do_strip = True
texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
message = message.strip() if do_strip else message
texts.append(f'{message} [/INST]')
return ''.join(texts)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Predict Tokens using `generate()` API for npu model"
)
parser.add_argument(
"--repo-id-or-model-path",
type=str,
default=r"C:\\Llama2-converted-weights\\",
help="The folder path of converted model blobs",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
parser.add_argument("--max-output-len", type=int, default=1024)
args = parser.parse_args()
model_path = args.repo_id_or_model_path
model = AutoModelForCausalLM.from_pretrained(model_path,
ov_model=True,
max_output_len=args.max_output_len,
model_name="Model70")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
DEFAULT_SYSTEM_PROMPT = """\
"""
print("-" * 80)
print("done")
with torch.inference_mode():
print("finish to load")
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
_input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict,
)
end = time.time()
print(f"Inference time: {end-st} s")
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str)
print("-" * 80)
print("done")
print("success shut down")

View file

@ -0,0 +1,17 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .pipeline_model import *

View file

@ -0,0 +1,64 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
import ctypes
import pathlib
from ipex_llm.utils.common import invalidInputError
def get_shared_lib_info(lib_base_name: str):
# Determine the file extension based on the platform
if sys.platform.startswith("linux") or sys.platform == "darwin":
lib_ext = ".so"
elif sys.platform == "win32":
lib_ext = ".dll"
else:
invalidInputError(False, "Unsupported platform.")
# Construct the paths to the possible shared library names (python/llm/src/ipex-llm/llm/libs)
_base_path = pathlib.Path(__file__).parent.parent.parent.resolve()
_base_path = _base_path / 'libs'
lib_path = os.path.join(_base_path, lib_base_name + lib_ext)
return _base_path, lib_path
_, _lib_path = get_shared_lib_info("pipeline")
# Load the library
_lib = ctypes.cdll.LoadLibrary(_lib_path)
_lib.InitLLMPipeline.argtypes = [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5
_lib.InitLLMPipeline.restype = ctypes.c_int
_lib.generate_serve.argtypes = [ctypes.c_int] * 5
_lib.generate_serve.restype = ctypes.c_int
def InitLLMPipeline(kv_len: int, num_head: int, head_dim: int, num_layers: int, vocab_size: int,
model_weight_dir: str, model_name: str,
first_blob_name: str, last_blob_name: str, rest_blob_name: str):
return _lib.InitLLMPipeline(kv_len, num_head, head_dim, num_layers, vocab_size,
model_weight_dir.encode('utf-8'), model_name.encode('utf-8'),
first_blob_name.encode('utf-8'), last_blob_name.encode('utf-8'),
rest_blob_name.encode('utf-8'))
def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int,
param_n_output: int):
_lib.generate_serve(kv_len, num_head, head_dim, num_layers, param_n_output)

View file

@ -0,0 +1,246 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import numpy
import warnings
import torch
import sys
import transformers
from typing import List
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from .pipeline_cpp import InitLLMPipeline, generate_serve
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
import threading
from ipex_llm.utils.common import invalidInputError
import os
from transformers import PretrainedConfig
def patch_flash_attn_import(filename: str) -> List[str]:
"""Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
imports = get_imports(filename)
if "flash_attn" in imports:
imports.remove("flash_attn")
return imports
def ignore_argument(kwargs: dict, key: "str"):
arg = kwargs.pop(key, None)
if arg is not None:
warnings.warn(f"argument `{key}={arg}` will be ignored")
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
new_generate_kwargs = {}
for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
value = kwargs.pop(var, None)
if value is not None:
new_generate_kwargs[var] = value
if isinstance(inputs[0], torch.Tensor):
numpy_input = inputs[0].numpy()
else:
numpy_input = inputs[0]
input_length = numpy.size(numpy_input)
new_tokens = new_generate_kwargs['max_new_tokens']
invalidInputError(input_length + new_tokens <= self.kv_len + 1,
"Input plus output tokens should not exceed max_output_len.")
# start generate_serve by Thread
thread = threading.Thread(target=generate_serve,
args=(self.kv_len, self.num_head,
self.head_dim, self.num_layers,
new_tokens))
thread.start()
in_pipe_path = "\\\\.\\pipe\\llminputpipe"
out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"
while True:
try:
input_pipe = open(in_pipe_path, "wb")
except:
print('Waiting for input pipe')
time.sleep(1)
else:
break
while True:
try:
output_pipe = open(out_pipe_path, "rb")
except:
print('Waiting for output pipe')
time.sleep(1)
else:
break
bdata = b''
for i in range(0, input_length):
d = int(numpy_input[i])
bdata = bdata + d.to_bytes(4, sys.byteorder)
if "eos_token_id" not in new_generate_kwargs:
eos = 0xffffffff
else:
eos = new_generate_kwargs["eos_token_id"]
bdata = bdata + eos.to_bytes(4, sys.byteorder)
input_pipe.write(bytearray(bdata))
input_pipe.flush()
buffersize = 4
output_tokens = []
while True:
data = output_pipe.read(buffersize)
if len(data) == 0:
break
token = int.from_bytes(data, sys.byteorder)
output_tokens.append(torch.tensor([token]))
if streamer is not None:
streamer.put(torch.tensor([token]))
if token == eos:
break
output = torch.stack(output_tokens, dim=1)
if streamer is not None:
streamer.end()
thread.join()
return output
class NPUModel():
def __init__(self):
pass
class _BaseAutoModelClass:
HF_MODEL = None
@classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""
Load a model from a directory or the HF Hub.
The loaded model will run supported OPs on NPU, then run other OPs on CPU.
Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
:param ov_model: boolean value, whether load blob files from specified directory.
If it's False, will convert HF model to specified blob format,
but which is not supported now. Default to True.
:param max_output_len: Maximum context length for whole generation, default to 1024.
:param model_name: Name prefix of the model weight bin file.
:return: a model instance
"""
ov_model = kwargs.get("ov_model", True)
max_output_len = kwargs.pop("max_output_len", 1024)
invalidInputError(ov_model,
"Original HF model is not supported now.")
invalidInputError(os.path.exists(pretrained_model_name_or_path),
"This directory does not exist, please double check it.")
config_json = os.path.join(pretrained_model_name_or_path, "config.json")
invalidInputError(os.path.exists(config_json),
"config.json is not found in current directory, please double check it.")
config = PretrainedConfig.from_json_file(config_json)
model = NPUModel()
model.kv_len = max_output_len - 1
model.num_head = config.num_attention_heads
model.head_dim = config.hidden_size // config.num_attention_heads
model.num_layers = config.num_hidden_layers
model.vocab_size = config.vocab_size
model_weight_dir = os.path.join(pretrained_model_name_or_path, "model_layer_weights")
model_name = kwargs.get("model_name", "Model")
first_blob_name = os.path.join(pretrained_model_name_or_path, "first_model.blob")
last_blob_name = os.path.join(pretrained_model_name_or_path, "last_model.blob")
rest_blob_name = os.path.join(pretrained_model_name_or_path, "rest_model.blob")
for path in [model_weight_dir, first_blob_name, last_blob_name, rest_blob_name]:
invalidInputError(os.path.exists(path),
f"{path} is not found in current directory, please double check it.")
try:
res = InitLLMPipeline(model.kv_len, model.num_head, model.head_dim, model.num_layers,
model.vocab_size, model_weight_dir, model_name,
first_blob_name, last_blob_name, rest_blob_name)
except:
invalidInputError(False,
"False to InitLLMPipeline.")
exit(0)
# patch generate function
import types
model.generate = types.MethodType(generate, model)
return model
class AutoModelForCausalLM(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForCausalLM
class AutoModel(_BaseAutoModelClass):
HF_Model = transformers.AutoModel
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForSpeechSeq2Seq
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForSeq2SeqLM
class AutoModelForSequenceClassification(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForSequenceClassification
class AutoModelForMaskedLM(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForMaskedLM
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForQuestionAnswering
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForNextSentencePrediction
class AutoModelForMultipleChoice(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForMultipleChoice
class AutoModelForTokenClassification(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForTokenClassification