Initial integration of our L0 Llama impl into ipex-llm (#12255)

* temp save

* initial support

* fix

* simplify code

* fix style

* fix example

* set default value of pipeline to False
Ruonan Wang 2024-10-24 09:49:27 +08:00 committed by GitHub
parent cacc891962
commit 821fd96367
7 changed files with 481 additions and 303 deletions

View file

@@ -18,7 +18,7 @@
import torch
import time
import argparse
from ipex_llm.transformers.npu_pipeline_model import AutoModelForCausalLM
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils import logging
@@ -44,8 +44,9 @@ if __name__ == "__main__":
parser.add_argument(
"--repo-id-or-model-path",
type=str,
default=r"C:\\Llama2-converted-weights\\",
help="The folder path of converted model blobs",
default="meta-llama/Llama-2-7b-chat-hf",
help="The huggingface repo id for the Llama2 model to be downloaded"
", or the path to the huggingface checkpoint folder",
)
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
@@ -56,9 +57,9 @@ if __name__ == "__main__":
model_path = args.repo_id_or_model_path
model = AutoModelForCausalLM.from_pretrained(model_path,
ov_model=True,
max_output_len=args.max_output_len,
model_name="Model70")
optimize_model=True,
pipeline=True,
max_output_len=args.max_output_len)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -68,22 +69,23 @@ if __name__ == "__main__":
print("-" * 80)
print("done")
with torch.inference_mode():
print("finish to load")
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
_input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict,
)
end = time.time()
print(f"Inference time: {end-st} s")
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str)
for i in range(5):
print("finish to load")
prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
_input_ids = tokenizer.encode(prompt, return_tensors="pt")
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
)
end = time.time()
print(f"Inference time: {end-st} s")
input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
print("-" * 20, "Input", "-" * 20)
print(input_str)
output_str = tokenizer.decode(output[0], skip_special_tokens=False)
print("-" * 20, "Output", "-" * 20)
print(output_str)
print("-" * 80)
print("done")

View file

@@ -120,6 +120,7 @@ class _BaseAutoModelClass:
ignore_argument(kwargs, "speculative")
ignore_argument(kwargs, "pipeline_parallel_stages")
optimize_model = kwargs.pop("optimize_model", False)
pipeline = kwargs.pop("pipeline", False)
max_output_len = kwargs.pop("max_output_len", 1024)
max_output_len = max_output_len - 1
max_prompt_len = kwargs.pop("max_prompt_len", 512)
@@ -184,16 +185,22 @@ class _BaseAutoModelClass:
model.config.update({"bigdl_transformers_low_bit": qtype})
model.share_memory()
optimize_llm(
llm,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
model.save_low_bit = types.MethodType(save_low_bit, model)
if not pipeline:
optimize_llm(
llm,
max_output_len=max_output_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
model.save_low_bit = types.MethodType(save_low_bit, model)
else:
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline import convert_llm
convert_llm(llm,
kv_len=max_output_len,
transpose_value_cache=transpose_value_cache)
else:
from ipex_llm.transformers.npu_models.convert import optimize_llm
optimize_llm(model)
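
The max_output_len handling above (the kwarg is decremented by one before being passed to convert_llm as kv_len) pairs with the length check in the new generate() further down. A small worked example of the resulting token budget, with illustrative values:

# Token-budget sketch: how the user-facing max_output_len relates to kv_len and the
# check `input_length + new_tokens <= self.kv_len + 1` in the new generate().
max_output_len = 1024               # user-facing from_pretrained kwarg
kv_len = max_output_len - 1         # value passed to convert_llm(kv_len=...), i.e. 1023

input_length = 32                   # example prompt length in tokens
max_new_tokens = 992                # example generation budget
assert input_length + max_new_tokens <= kv_len + 1   # equivalent to <= max_output_len
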

View file

@@ -13,5 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .pipeline_model import *

View file

@@ -0,0 +1,322 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openvino.runtime import Core, serialize
import os
import torch
from ipex_llm.utils.common import invalidInputError
import time
import sys
from typing import List
from .pipeline_cpp import InitLLMPipeline, generate_serve
from typing import Callable, List, Optional
from transformers import GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
import threading
from ipex_llm.utils.common import invalidInputError
import tempfile
import numpy as np
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
# if do_print=True, output timing message
do_print = kwargs.pop("do_print", False)
time_start_all, time_t1, idx = time.perf_counter(), None, 0
new_generate_kwargs = {}
for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
value = kwargs.pop(var, None)
if value is not None:
new_generate_kwargs[var] = value
if isinstance(inputs[0], torch.Tensor):
numpy_input = inputs[0].numpy()
else:
numpy_input = inputs[0]
input_length = np.size(numpy_input)
new_tokens = new_generate_kwargs['max_new_tokens']
invalidInputError(input_length + new_tokens <= self.kv_len + 1,
"Input plus output tokens should not exceed max_output_len.")
# start generate_serve by Thread
thread = threading.Thread(target=generate_serve,
args=(self.kv_len, self.num_head,
self.head_dim, self.num_layers,
self.transpose_value_cache,
new_tokens - 1))
thread.start()
in_pipe_path = "\\\\.\\pipe\\llminputpipe"
out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"
while True:
try:
input_pipe = open(in_pipe_path, "wb")
except:
print('Waiting for input pipe')
time.sleep(1)
else:
break
while True:
try:
output_pipe = open(out_pipe_path, "rb")
except:
print('Waiting for output pipe')
time.sleep(1)
else:
break
bdata = b''
for i in range(0, input_length):
d = int(numpy_input[i])
bdata = bdata + d.to_bytes(4, sys.byteorder)
if "eos_token_id" not in new_generate_kwargs:
eos = 0xffffffff
else:
eos = new_generate_kwargs["eos_token_id"]
bdata = bdata + eos.to_bytes(4, sys.byteorder)
time_start = time.perf_counter()
input_pipe.write(bytearray(bdata))
input_pipe.flush()
buffersize = 4
output_tokens = []
while True:
data = output_pipe.read(buffersize)
if len(data) == 0:
break
token = int.from_bytes(data, sys.byteorder)
idx += 1
if time_t1 is None:
time_t1 = time.perf_counter()
output_tokens.append(torch.tensor([token]))
if streamer is not None:
streamer.put(torch.tensor([token]))
if token == eos:
break
output = torch.stack(output_tokens, dim=1)
output = torch.cat((inputs, output), dim=1)
if streamer is not None:
streamer.end()
thread.join()
time_end = time.perf_counter()
if do_print:
print(f" Start the thread and connect the pipe time: {(time_start - time_start_all):.2f} s")
print(f" Number of input tokens: {input_length}")
print(f" Generated tokens: {idx}")
print(f" First token generation time: {(time_t1 - time_start):.2f} s")
print(f" Generation average latency: {(time_end - time_t1)*1000 /(idx - 1):.2f} ms, "
f"({(idx - 1)/(time_end - time_t1):.2f} token/s)")
print(f" Generation time: {(time_end - time_start):.2f} s\n")
return output
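
Both named pipes above carry plain 4-byte integers in native byte order: the writer sends the input token ids followed by the eos id, and the reader consumes generated tokens four bytes at a time until eos (or an empty read). A self-contained sketch of that framing, using an in-memory buffer in place of the actual pipe so it runs without the C++ serving thread:

# Sketch of the 4-byte token framing used on the pipes; io.BytesIO stands in for the pipe.
import io
import sys

tokens = [1, 15043, 3186]      # example token ids (illustrative)
eos = 2                        # example eos_token_id

payload = b''.join(t.to_bytes(4, sys.byteorder) for t in tokens)
payload += eos.to_bytes(4, sys.byteorder)          # eos terminates the stream

reader = io.BytesIO(payload)
received = []
while True:
    data = reader.read(4)                          # buffersize = 4, as in generate()
    if len(data) == 0:
        break
    token = int.from_bytes(data, sys.byteorder)
    if token == eos:
        break
    received.append(token)

assert received == tokens
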
def update_names_of_IR_and_export_blob(model, model_name, dir):
xml_path = os.path.join(dir, model_name + ".xml")
model.save(xml_path)
new_ir_path = os.path.join(dir, model_name + "_new.xml")
blob_path = os.path.join(dir, model_name + ".blob")
core = Core()
core.set_property("NPU", {"NPU_COMPILATION_MODE_PARAMS":
"compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"})
core.set_property("NPU", {"PERFORMANCE_HINT": "LATENCY"})
model = core.read_model(xml_path)
inputs = model.inputs
for idx, input in enumerate(inputs):
if len(input.names) == 0:
model.inputs[idx].set_names({f"input_{idx}"})
outputs = model.outputs
for idx, input in enumerate(outputs):
if len(input.names) == 0:
model.outputs[idx].set_names({f"output_{idx}"})
# rewrite this model to a new IR path
if new_ir_path is not None:
serialize(model, new_ir_path)
if blob_path is not None:
compiledModel = core.compile_model(model, device_name="NPU")
model_stream = compiledModel.export_model()
with open(blob_path, 'wb') as f:
f.write(model_stream)
os.remove(xml_path)
os.remove(new_ir_path)
return blob_path
def convert_llm(model: torch.nn.Module,
kv_len: int,
transpose_value_cache: bool):
if model.config.model_type == "llama":
from .llama import LowBitLlamaLMHead, LlamaEmbedding
with tempfile.TemporaryDirectory() as temp_dir:
# generate lm_head blob
weight_dir = os.path.join(temp_dir, "model_weights")
os.mkdir(weight_dir)
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
intermediate_size = model.config.intermediate_size
layer_num = len(model.model.layers)
rms_norm_eps = model.config.rms_norm_eps
vocab_size = model.config.vocab_size
model_norm = model.model.norm
lm_head = model.lm_head
weights = [(lm_head.weight, lm_head.scale)]
if isinstance(weights[0], tuple):
np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
else: # FP16 Linear
np_dtype = np.float16
new_lm_head = LowBitLlamaLMHead(
[1, 1, num_heads * head_dim],
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
max_seq_len=kv_len,
rms_norm_eps=rms_norm_eps,
mode="decode",
transpose_value=False,
dtype=np_dtype,
model_norm_weight=model_norm.weight.to(torch.float16),
vocab_size=vocab_size,
)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
# save weights bins files
weight_numpy = [
lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
]
for idx, weight in enumerate(weight_numpy):
bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
weight.tofile(bin_file)
embedding_layer = model.model.embed_tokens
new_embedding = LlamaEmbedding(
vocab_size=model.config.vocab_size,
embedding_dim=model.config.hidden_size,
padding_idx=model.config.pad_token_id,
dtype=np.float16,
)
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
temp_dir)
bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
# generate decoder layer blob
from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
for layer_idx in range(0, layer_num):
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp
weights = [
(attn_layer.q_proj.weight, attn_layer.q_proj.scale),
(attn_layer.k_proj.weight, attn_layer.k_proj.scale),
(attn_layer.v_proj.weight, attn_layer.v_proj.scale),
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
]
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
if isinstance(weights[0], tuple):
np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
else: # FP16 Linear
np_dtype = np.float16
if layer_idx == 0:
single_decoder = LowBitLlamaMultiDecoderlayer(
[1, 1, num_heads * head_dim],
input_layernorm_weights=None,
post_attn_layernorm_weights=None,
cached_cos=cached_cos,
cached_sin=cached_sin,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
num_layers=1,
max_seq_len=kv_len,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
mode="decode",
transpose_value=transpose_value_cache,
dtype=np_dtype,
)
rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
"decoder_layer",
temp_dir)
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{7+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{7+idx*2+1}.bin")
scale.numpy().tofile(bin_file)
# patch attrs for generate
model.kv_len = kv_len
model.num_head = num_heads
model.head_dim = head_dim
model.num_layers = layer_num
model.transpose_value_cache = transpose_value_cache
try:
res = InitLLMPipeline(kv_len, num_heads, head_dim, layer_num,
model.vocab_size, weight_dir, "model",
first_blob_path, last_blob_path, rest_blob_path)
except:
invalidInputError(False,
"False to InitLLMPipeline.")
else:
invalidInputError(False,
"Now we only support Llama2 for pipeline running.")
# patch generate function
import types
model.generate = types.MethodType(generate, model)
return model
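
The per-layer weight files written by convert_llm follow a fixed naming convention: the two layer norms go to input slots 3 and 4, and each (weight, scale) pair of the seven projections occupies consecutive slots starting at 7. A small sketch that reproduces the expected file names (directory name and projection order taken from the code above; purely illustrative):

# Sketch of the weight-bin naming convention used by convert_llm (illustrative only).
import os

weight_dir = "model_weights"   # created under a temporary directory in convert_llm
proj_order = ["q_proj", "k_proj", "v_proj", "o_proj",
              "gate_proj", "up_proj", "down_proj"]   # matches the weights list above

def layer_bin_files(layer_idx):
    files = {
        "input_layernorm": f"model_{layer_idx}_input_3.bin",
        "post_attention_layernorm": f"model_{layer_idx}_input_4.bin",
    }
    for i, name in enumerate(proj_order):
        files[f"{name}.weight"] = f"model_{layer_idx}_input_{7 + i * 2}.bin"
        files[f"{name}.scale"] = f"model_{layer_idx}_input_{7 + i * 2 + 1}.bin"
    return {k: os.path.join(weight_dir, v) for k, v in files.items()}

print(layer_bin_files(0)["q_proj.weight"])   # model_weights/model_0_input_7.bin (backslashes on Windows)
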

View file

@@ -0,0 +1,114 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
from typing import Sequence
from intel_npu_acceleration_library.backend.factory import NNFactory
class LowBitLlamaLMHead(LLMBaseNNFactory):
def __init__(
self,
hidden_shape: Sequence[int],
num_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
model_norm_weight,
vocab_size: int,
mode: str = "decode",
dtype: np.dtype = np.int8,
max_seq_len: int = 1024,
transpose_value: bool = False,
profile: bool = False,
device: str = "NPU",
):
super().__init__(max_seq_len=max_seq_len,
transpose_value=transpose_value,
dtype=dtype,
profile=profile,
device=device)
self.max_seq_len = max_seq_len
self.dtype = dtype
self.batch_size, self.seq_len, self.hidden_size = hidden_shape
self.mode = mode
self.rms_norm_eps = rms_norm_eps
self.transpose_value = transpose_value
self.vocab_size = vocab_size
self.num_heads = num_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
# define input, the order self.parameter matters
input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
hidden_states = input
# model norm and lm head
model_norm_weight = self.constant(model_norm_weight)
hidden_states = self.layer_norm(hidden_states, model_norm_weight)
hidden_states = self.linear(
hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
# define outputs
hidden_states = self.convert_to_fp32(hidden_states)
print("start compiling")
self.compile()
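
For reference, the graph built by LowBitLlamaLMHead amounts to an RMS norm with the model's final norm weight followed by a linear projection to the vocabulary. A rough PyTorch sketch of that computation (toy sizes; the NPU graph additionally runs the linear in low-bit precision, which this reference ignores, and eps stands in for the model's rms_norm_eps):

# Rough PyTorch reference for the lm_head graph above (sketch, not the NPU implementation).
import torch
import torch.nn.functional as F

def lm_head_reference(hidden_states, norm_weight, lm_head_weight, eps=1e-6):
    # hidden_states: (batch, seq, hidden); norm_weight: (hidden,); lm_head_weight: (vocab, hidden)
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states * torch.rsqrt(variance + eps)          # RMS norm
    normed = norm_weight * normed.to(hidden_states.dtype)
    return F.linear(normed, lm_head_weight)                       # logits: (batch, seq, vocab)

hidden = torch.randn(1, 1, 64)                                    # toy hidden size for a quick check
logits = lm_head_reference(hidden, torch.ones(64), torch.randn(100, 64))
print(logits.shape)                                               # torch.Size([1, 1, 100])
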
class LlamaEmbedding(NNFactory):
def __init__(
self,
vocab_size,
embedding_dim,
padding_idx,
dtype, # fp16
device: str = "NPU",
):
super().__init__(False, device)
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.dtype = dtype
# define input
weight = self.parameter((vocab_size, embedding_dim))
input = self.parameter((1, 1), dtype=np.int32)
if padding_idx == -1:
padding_idx += vocab_size
if padding_idx is not None:
masked_embeddings = np.ones(weight.shape, dtype='int64')
masked_embeddings[padding_idx, :] = 0 # mask
node_mask = self.constant(masked_embeddings)
node_masked_w = self.matmul(weight, node_mask, False, True)
axis_node = self.constant(np.array([0], dtype=np.int64))
res = self.gather(node_masked_w if padding_idx else weight, input, axis_node, 0)
# define outputs
res = self.convert_to_fp16(res)
print("start compiling")
self.compile()
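
The intent of LlamaEmbedding above is an embedding lookup whose padding row is forced to zero; the NPU graph expresses this with constant, matmul, and gather ops, but the end result can be pictured with a few lines of NumPy (toy sizes, illustrative only):

# NumPy picture of the masked embedding lookup built above (toy sizes, sketch only).
import numpy as np

vocab_size, embedding_dim, padding_idx = 8, 4, 0
weight = np.random.rand(vocab_size, embedding_dim).astype(np.float16)

masked_weight = weight.copy()
masked_weight[padding_idx, :] = 0                # padding token embeds to all zeros

input_ids = np.array([[3]], dtype=np.int32)      # shape (1, 1), matching the graph input
embeddings = masked_weight[input_ids]            # gather along axis 0 -> (1, 1, embedding_dim)
print(embeddings.shape)                          # (1, 1, 4)
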

View file

@@ -46,7 +46,7 @@ _lib = ctypes.cdll.LoadLibrary(_lib_path)
_lib.InitLLMPipeline.argtypes = [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5
_lib.InitLLMPipeline.restype = ctypes.c_int
_lib.generate_serve.argtypes = [ctypes.c_int] * 5
_lib.generate_serve.argtypes = [ctypes.c_int] * 4 + [ctypes.c_bool] + [ctypes.c_int]
_lib.generate_serve.restype = ctypes.c_int
@@ -60,5 +60,6 @@ def InitLLMPipeline(kv_len: int, num_head: int, head_dim: int, num_layers: int,
def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int,
param_n_output: int):
_lib.generate_serve(kv_len, num_head, head_dim, num_layers, param_n_output)
transpose_value_cache: bool, param_n_output: int):
_lib.generate_serve(kv_len, num_head, head_dim, num_layers,
transpose_value_cache, param_n_output)

View file

@@ -1,266 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import numpy
import warnings
import torch
import sys
import transformers
from typing import List
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from .pipeline_cpp import InitLLMPipeline, generate_serve
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
import threading
from ipex_llm.utils.common import invalidInputError
import os
from transformers import PretrainedConfig
def patch_flash_attn_import(filename: str) -> List[str]:
"""Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
imports = get_imports(filename)
if "flash_attn" in imports:
imports.remove("flash_attn")
return imports
def ignore_argument(kwargs: dict, key: "str"):
arg = kwargs.pop(key, None)
if arg is not None:
warnings.warn(f"argument `{key}={arg}` will be ignored")
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
# if do_print=True, output timing message
do_print = kwargs.pop("do_print", False)
time_start_all, time_t1, idx = time.perf_counter(), None, 0
new_generate_kwargs = {}
for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
value = kwargs.pop(var, None)
if value is not None:
new_generate_kwargs[var] = value
if isinstance(inputs[0], torch.Tensor):
numpy_input = inputs[0].numpy()
else:
numpy_input = inputs[0]
input_length = numpy.size(numpy_input)
new_tokens = new_generate_kwargs['max_new_tokens']
invalidInputError(input_length + new_tokens <= self.kv_len + 1,
"Input plus output tokens should not exceed max_output_len.")
# start generate_serve by Thread
thread = threading.Thread(target=generate_serve,
args=(self.kv_len, self.num_head,
self.head_dim, self.num_layers,
new_tokens - 1))
thread.start()
in_pipe_path = "\\\\.\\pipe\\llminputpipe"
out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"
while True:
try:
input_pipe = open(in_pipe_path, "wb")
except:
print('Waiting for input pipe')
time.sleep(1)
else:
break
while True:
try:
output_pipe = open(out_pipe_path, "rb")
except:
print('Waiting for output pipe')
time.sleep(1)
else:
break
bdata = b''
for i in range(0, input_length):
d = int(numpy_input[i])
bdata = bdata + d.to_bytes(4, sys.byteorder)
if "eos_token_id" not in new_generate_kwargs:
eos = 0xffffffff
else:
eos = new_generate_kwargs["eos_token_id"]
bdata = bdata + eos.to_bytes(4, sys.byteorder)
time_start = time.perf_counter()
input_pipe.write(bytearray(bdata))
input_pipe.flush()
buffersize = 4
output_tokens = []
while True:
data = output_pipe.read(buffersize)
if len(data) == 0:
break
token = int.from_bytes(data, sys.byteorder)
idx += 1
if time_t1 is None:
time_t1 = time.perf_counter()
output_tokens.append(torch.tensor([token]))
if streamer is not None:
streamer.put(torch.tensor([token]))
if token == eos:
break
output = torch.stack(output_tokens, dim=1)
output = torch.cat((inputs, output), dim=1)
if streamer is not None:
streamer.end()
thread.join()
time_end = time.perf_counter()
if do_print:
print(f" Start the thread and connect the pipe time: {(time_start - time_start_all):.2f} s")
print(f" Number of input tokens: {input_length}")
print(f" Generated tokens: {idx}")
print(f" First token generation time: {(time_t1 - time_start):.2f} s")
print(f" Generation average latency: {(time_end - time_t1)*1000 /(idx - 1):.2f} ms, "
f"({(idx - 1)/(time_end - time_t1):.2f} token/s)")
print(f" Generation time: {(time_end - time_start):.2f} s\n")
return output
class NPUModel():
def __init__(self):
pass
class _BaseAutoModelClass:
HF_MODEL = None
@classmethod
@patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""
Load a model from a directory or the HF Hub.
The loaded model will run supported OPs on NPU, then run other OPs on CPU.
Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
:param ov_model: boolean value, whether load blob files from specified directory.
If it's False, will convert HF model to specified blob format,
but which is not supported now. Default to True.
:param max_output_len: Maximum context length for whole generation, default to 1024.
:param model_name: Name prefix of the model weight bin file.
:return: a model instance
"""
ov_model = kwargs.get("ov_model", True)
max_output_len = kwargs.pop("max_output_len", 1024)
invalidInputError(ov_model,
"Original HF model is not supported now.")
invalidInputError(os.path.exists(pretrained_model_name_or_path),
"This directory does not exist, please double check it.")
config_json = os.path.join(pretrained_model_name_or_path, "config.json")
invalidInputError(os.path.exists(config_json),
"config.json is not found in current directory, please double check it.")
config = PretrainedConfig.from_json_file(config_json)
model = NPUModel()
model.kv_len = max_output_len - 1
model.num_head = config.num_attention_heads
model.head_dim = config.hidden_size // config.num_attention_heads
model.num_layers = config.num_hidden_layers
model.vocab_size = config.vocab_size
model_weight_dir = os.path.join(pretrained_model_name_or_path, "model_layer_weights")
model_name = kwargs.get("model_name", "Model")
first_blob_name = os.path.join(pretrained_model_name_or_path, "first_model.blob")
last_blob_name = os.path.join(pretrained_model_name_or_path, "last_model.blob")
rest_blob_name = os.path.join(pretrained_model_name_or_path, "rest_model.blob")
for path in [model_weight_dir, first_blob_name, last_blob_name, rest_blob_name]:
invalidInputError(os.path.exists(path),
f"{path} is not found in current directory, please double check it.")
try:
res = InitLLMPipeline(model.kv_len, model.num_head, model.head_dim, model.num_layers,
model.vocab_size, model_weight_dir, model_name,
first_blob_name, last_blob_name, rest_blob_name)
except:
invalidInputError(False,
"False to InitLLMPipeline.")
exit(0)
# patch generate function
import types
model.generate = types.MethodType(generate, model)
return model
class AutoModelForCausalLM(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForCausalLM
class AutoModel(_BaseAutoModelClass):
HF_Model = transformers.AutoModel
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForSpeechSeq2Seq
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForSeq2SeqLM
class AutoModelForSequenceClassification(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForSequenceClassification
class AutoModelForMaskedLM(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForMaskedLM
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForQuestionAnswering
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForNextSentencePrediction
class AutoModelForMultipleChoice(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForMultipleChoice
class AutoModelForTokenClassification(_BaseAutoModelClass):
HF_Model = transformers.AutoModelForTokenClassification