Initial integration of our L0 Llama impl into ipex-llm (#12255)
* temp save
* initial support
* fix
* simplify code
* fix style
* fix example
* make default value of pipeline as False
This commit is contained in:
parent cacc891962
commit 821fd96367
7 changed files with 481 additions and 303 deletions
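
Summary for reviewers: with this change the NPU Llama example loads the Hugging Face checkpoint through ipex_llm.transformers.npu_model and opts into the new L0 (level-zero) pipeline via the pipeline flag, which defaults to False. A minimal usage sketch based on the updated example below; the model path and token counts are placeholders, and from_pretrained arguments not touched by this diff are omitted:

    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # or a local checkpoint folder
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 optimize_model=True,
                                                 pipeline=True,  # new kwarg in this PR, defaults to False
                                                 max_output_len=1024)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with torch.inference_mode():
        input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
        # do_print=True makes the patched pipeline generate() report timing
        output = model.generate(input_ids, max_new_tokens=32, do_print=True)
        print(tokenizer.decode(output[0], skip_special_tokens=False))

Because pipeline defaults to False, existing npu_model users keep the previous optimize_llm path unless they opt in.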
@@ -18,7 +18,7 @@
 import torch
 import time
 import argparse
-from ipex_llm.transformers.npu_pipeline_model import AutoModelForCausalLM
+from ipex_llm.transformers.npu_model import AutoModelForCausalLM
 from transformers import AutoTokenizer
 from transformers.utils import logging

@@ -44,8 +44,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--repo-id-or-model-path",
         type=str,
-        default=r"C:\\Llama2-converted-weights\\",
-        help="The folder path of converted model blobs",
+        default="meta-llama/Llama-2-7b-chat-hf",
+        help="The huggingface repo id for the Llama2 model to be downloaded"
+             ", or the path to the huggingface checkpoint folder",
     )
     parser.add_argument('--prompt', type=str, default="What is AI?",
                         help='Prompt to infer')

@@ -56,9 +57,9 @@ if __name__ == "__main__":
     model_path = args.repo_id_or_model_path

     model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                 ov_model=True,
-                                                 max_output_len=args.max_output_len,
-                                                 model_name="Model70")
+                                                 optimize_model=True,
+                                                 pipeline=True,
+                                                 max_output_len=args.max_output_len)

     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

@@ -68,22 +69,23 @@ if __name__ == "__main__":
     print("-" * 80)
     print("done")
     with torch.inference_mode():
-        print("finish to load")
-        prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
-        _input_ids = tokenizer.encode(prompt, return_tensors="pt")
-        print("input length:", len(_input_ids[0]))
-        st = time.time()
-        output = model.generate(
-            _input_ids, max_new_tokens=args.n_predict,
-        )
-        end = time.time()
-        print(f"Inference time: {end-st} s")
-        input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
-        print("-" * 20, "Input", "-" * 20)
-        print(input_str)
-        output_str = tokenizer.decode(output[0], skip_special_tokens=False)
-        print("-" * 20, "Output", "-" * 20)
-        print(output_str)
+        for i in range(5):
+            print("finish to load")
+            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
+            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
+            print("input length:", len(_input_ids[0]))
+            st = time.time()
+            output = model.generate(
+                _input_ids, max_new_tokens=args.n_predict, do_print=True
+            )
+            end = time.time()
+            print(f"Inference time: {end-st} s")
+            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
+            print("-" * 20, "Input", "-" * 20)
+            print(input_str)
+            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
+            print("-" * 20, "Output", "-" * 20)
+            print(output_str)

     print("-" * 80)
     print("done")
@@ -120,6 +120,7 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")
         optimize_model = kwargs.pop("optimize_model", False)
+        pipeline = kwargs.pop("pipeline", False)
         max_output_len = kwargs.pop("max_output_len", 1024)
         max_output_len = max_output_len - 1
         max_prompt_len = kwargs.pop("max_prompt_len", 512)

@@ -184,16 +185,22 @@ class _BaseAutoModelClass:
             model.config.update({"bigdl_transformers_low_bit": qtype})
             model.share_memory()

-            optimize_llm(
-                llm,
-                max_output_len=max_output_len,
-                max_prompt_len=max_prompt_len,
-                inter_pp=inter_pp,
-                intra_pp=intra_pp,
-                transpose_value_cache=transpose_value_cache,
-                group_size=quantization_group_size
-            )
-            model.save_low_bit = types.MethodType(save_low_bit, model)
+            if not pipeline:
+                optimize_llm(
+                    llm,
+                    max_output_len=max_output_len,
+                    max_prompt_len=max_prompt_len,
+                    inter_pp=inter_pp,
+                    intra_pp=intra_pp,
+                    transpose_value_cache=transpose_value_cache,
+                    group_size=quantization_group_size
+                )
+                model.save_low_bit = types.MethodType(save_low_bit, model)
+            else:
+                from ipex_llm.transformers.npu_pipeline_model.convert_pipeline import convert_llm
+                convert_llm(llm,
+                            kv_len=max_output_len,
+                            transpose_value_cache=transpose_value_cache)
         else:
             from ipex_llm.transformers.npu_models.convert import optimize_llm
             optimize_llm(model)
@@ -13,5 +13,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from .pipeline_model import *
@@ -0,0 +1,322 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from openvino.runtime import Core, serialize
import os
import torch
from ipex_llm.utils.common import invalidInputError
import time
import sys
from typing import List
from .pipeline_cpp import InitLLMPipeline, generate_serve
from typing import Callable, List, Optional
from transformers import GenerationConfig, \
    LogitsProcessorList, StoppingCriteriaList
import threading
from ipex_llm.utils.common import invalidInputError
import tempfile
import numpy as np


def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    # if do_print=True, output timing message
    do_print = kwargs.pop("do_print", False)
    time_start_all, time_t1, idx = time.perf_counter(), None, 0
    new_generate_kwargs = {}
    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
        value = kwargs.pop(var, None)
        if value is not None:
            new_generate_kwargs[var] = value

    if isinstance(inputs[0], torch.Tensor):
        numpy_input = inputs[0].numpy()
    else:
        numpy_input = inputs[0]
    input_length = np.size(numpy_input)

    new_tokens = new_generate_kwargs['max_new_tokens']
    invalidInputError(input_length + new_tokens <= self.kv_len + 1,
                      "Input plus output tokens should not exceed max_output_len.")

    # start generate_serve by Thread
    thread = threading.Thread(target=generate_serve,
                              args=(self.kv_len, self.num_head,
                                    self.head_dim, self.num_layers,
                                    self.transpose_value_cache,
                                    new_tokens - 1))
    thread.start()

    in_pipe_path = "\\\\.\\pipe\\llminputpipe"
    out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"

    while True:
        try:
            input_pipe = open(in_pipe_path, "wb")
        except:
            print('Waiting for input pipe')
            time.sleep(1)
        else:
            break

    while True:
        try:
            output_pipe = open(out_pipe_path, "rb")
        except:
            print('Waiting for output pipe')
            time.sleep(1)
        else:
            break

    bdata = b''
    for i in range(0, input_length):
        d = int(numpy_input[i])
        bdata = bdata + d.to_bytes(4, sys.byteorder)

    if "eos_token_id" not in new_generate_kwargs:
        eos = 0xffffffff
    else:
        eos = new_generate_kwargs["eos_token_id"]

    bdata = bdata + eos.to_bytes(4, sys.byteorder)

    time_start = time.perf_counter()

    input_pipe.write(bytearray(bdata))
    input_pipe.flush()

    buffersize = 4
    output_tokens = []
    while True:
        data = output_pipe.read(buffersize)
        if len(data) == 0:
            break
        token = int.from_bytes(data, sys.byteorder)
        idx += 1
        if time_t1 is None:
            time_t1 = time.perf_counter()
        output_tokens.append(torch.tensor([token]))
        if streamer is not None:
            streamer.put(torch.tensor([token]))
        if token == eos:
            break

    output = torch.stack(output_tokens, dim=1)
    output = torch.cat((inputs, output), dim=1)
    if streamer is not None:
        streamer.end()

    thread.join()
    time_end = time.perf_counter()

    if do_print:
        print(f" Start the thread and connect the pipe time: {(time_start - time_start_all):.2f} s")
        print(f" Number of input tokens: {input_length}")
        print(f" Generated tokens: {idx}")
        print(f" First token generation time: {(time_t1 - time_start):.2f} s")
        print(f" Generation average latency: {(time_end - time_t1)*1000 /(idx - 1):.2f} ms, "
              f"({(idx - 1)/(time_end - time_t1):.2f} token/s)")
        print(f" Generation time: {(time_end - time_start):.2f} s\n")

    return output


def update_names_of_IR_and_export_blob(model, model_name, dir):
    xml_path = os.path.join(dir, model_name + ".xml")
    model.save(xml_path)
    new_ir_path = os.path.join(dir, model_name + "_new.xml")
    blob_path = os.path.join(dir, model_name + ".blob")

    core = Core()
    core.set_property("NPU", {"NPU_COMPILATION_MODE_PARAMS":
                              "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"})
    core.set_property("NPU", {"PERFORMANCE_HINT": "LATENCY"})
    model = core.read_model(xml_path)
    inputs = model.inputs
    for idx, input in enumerate(inputs):
        if len(input.names) == 0:
            model.inputs[idx].set_names({f"input_{idx}"})
    outputs = model.outputs
    for idx, input in enumerate(outputs):
        if len(input.names) == 0:
            model.outputs[idx].set_names({f"output_{idx}"})
    # rewrite this model to a new IR path
    if new_ir_path is not None:
        serialize(model, new_ir_path)

    if blob_path is not None:
        compiledModel = core.compile_model(model, device_name="NPU")
        model_stream = compiledModel.export_model()
        with open(blob_path, 'wb') as f:
            f.write(model_stream)

    os.remove(xml_path)
    os.remove(new_ir_path)

    return blob_path


def convert_llm(model: torch.nn.Module,
                kv_len: int,
                transpose_value_cache: bool):
    if model.config.model_type == "llama":
        from .llama import LowBitLlamaLMHead, LlamaEmbedding
        with tempfile.TemporaryDirectory() as temp_dir:
            # generate lm_head blob
            weight_dir = os.path.join(temp_dir, "model_weights")
            os.mkdir(weight_dir)
            num_heads = model.model.layers[0].self_attn.num_heads
            num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
            head_dim = model.model.layers[0].self_attn.head_dim
            intermediate_size = model.config.intermediate_size
            layer_num = len(model.model.layers)
            rms_norm_eps = model.config.rms_norm_eps
            vocab_size = model.config.vocab_size
            model_norm = model.model.norm
            lm_head = model.lm_head
            weights = [(lm_head.weight, lm_head.scale)]
            if isinstance(weights[0], tuple):
                np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
            else:  # FP16 Linear
                np_dtype = np.float16

            new_lm_head = LowBitLlamaLMHead(
                [1, 1, num_heads * head_dim],
                num_heads=num_heads,
                num_key_value_heads=num_key_value_heads,
                max_seq_len=kv_len,
                rms_norm_eps=rms_norm_eps,
                mode="decode",
                transpose_value=False,
                dtype=np_dtype,
                model_norm_weight=model_norm.weight.to(torch.float16),
                vocab_size=vocab_size,
            )
            last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)

            # save weights bins files
            weight_numpy = [
                lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
            ]

            for idx, weight in enumerate(weight_numpy):
                bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
                weight.tofile(bin_file)

            embedding_layer = model.model.embed_tokens
            new_embedding = LlamaEmbedding(
                vocab_size=model.config.vocab_size,
                embedding_dim=model.config.hidden_size,
                padding_idx=model.config.pad_token_id,
                dtype=np.float16,
            )
            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                                 temp_dir)
            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)

            # generate decoder layer blob
            from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
            for layer_idx in range(0, layer_num):
                curr_layer = model.model.layers[layer_idx]
                attn_layer = curr_layer.self_attn
                mlp_layer = curr_layer.mlp

                weights = [
                    (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
                    (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
                    (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
                    (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
                    (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
                    (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
                    (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
                ]

                cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
                cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
                layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
                layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

                if isinstance(weights[0], tuple):
                    np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
                else:  # FP16 Linear
                    np_dtype = np.float16

                if layer_idx == 0:
                    single_decoder = LowBitLlamaMultiDecoderlayer(
                        [1, 1, num_heads * head_dim],
                        input_layernorm_weights=None,
                        post_attn_layernorm_weights=None,
                        cached_cos=cached_cos,
                        cached_sin=cached_sin,
                        num_heads=num_heads,
                        num_key_value_heads=num_key_value_heads,
                        num_layers=1,
                        max_seq_len=kv_len,
                        rms_norm_eps=rms_norm_eps,
                        intermediate_size=intermediate_size,
                        mode="decode",
                        transpose_value=transpose_value_cache,
                        dtype=np_dtype,
                    )
                    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                                        "decoder_layer",
                                                                        temp_dir)

                input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
                post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
                layer_norm_0.data.numpy().tofile(input_lm_bin_file)
                layer_norm_1.data.numpy().tofile(post_lm_bin_file)

                for idx, (weight, scale) in enumerate(weights):
                    bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{7+idx*2}.bin")
                    weight.numpy().tofile(bin_file)
                    bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{7+idx*2+1}.bin")
                    scale.numpy().tofile(bin_file)

            # patch attrs for generate
            model.kv_len = kv_len
            model.num_head = num_heads
            model.head_dim = head_dim
            model.num_layers = layer_num
            model.transpose_value_cache = transpose_value_cache

            try:
                res = InitLLMPipeline(kv_len, num_heads, head_dim, layer_num,
                                      model.vocab_size, weight_dir, "model",
                                      first_blob_path, last_blob_path, rest_blob_path)
            except:
                invalidInputError(False,
                                  "Failed to InitLLMPipeline.")
    else:
        invalidInputError(False,
                          "Now we only support Llama2 for pipeline running.")

    # patch generate function
    import types
    model.generate = types.MethodType(generate, model)
    return model
python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py (new file, +114)
@@ -0,0 +1,114 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import numpy as np
from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
from typing import Sequence
from intel_npu_acceleration_library.backend.factory import NNFactory


class LowBitLlamaLMHead(LLMBaseNNFactory):
    def __init__(
        self,
        hidden_shape: Sequence[int],
        num_heads: int,
        num_key_value_heads: int,
        rms_norm_eps: float,
        model_norm_weight,
        vocab_size: int,
        mode: str = "decode",
        dtype: np.dtype = np.int8,
        max_seq_len: int = 1024,
        transpose_value: bool = False,
        profile: bool = False,
        device: str = "NPU",
    ):
        super().__init__(max_seq_len=max_seq_len,
                         transpose_value=transpose_value,
                         dtype=dtype,
                         profile=profile,
                         device=device)
        self.max_seq_len = max_seq_len
        self.dtype = dtype
        self.batch_size, self.seq_len, self.hidden_size = hidden_shape
        self.mode = mode
        self.rms_norm_eps = rms_norm_eps
        self.transpose_value = transpose_value
        self.vocab_size = vocab_size

        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads

        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # define input, the order self.parameter matters
        input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))

        hidden_states = input

        # model norm and lm head
        model_norm_weight = self.constant(model_norm_weight)
        hidden_states = self.layer_norm(hidden_states, model_norm_weight)
        hidden_states = self.linear(
            hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
        )

        # define outputs
        hidden_states = self.convert_to_fp32(hidden_states)

        print("start compiling")
        self.compile()


class LlamaEmbedding(NNFactory):
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        padding_idx,
        dtype,  # fp16
        device: str = "NPU",
    ):
        super().__init__(False, device)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.dtype = dtype

        # define input
        weight = self.parameter((vocab_size, embedding_dim))
        input = self.parameter((1, 1), dtype=np.int32)

        if padding_idx == -1:
            padding_idx += vocab_size

        if padding_idx is not None:
            masked_embeddings = np.ones(weight.shape, dtype='int64')
            masked_embeddings[padding_idx, :] = 0  # mask

            node_mask = self.constant(masked_embeddings)
            node_masked_w = self.matmul(weight, node_mask, False, True)

        axis_node = self.constant(np.array([0], dtype=np.int64))
        res = self.gather(node_masked_w if padding_idx else weight, input, axis_node, 0)

        # define outputs
        res = self.convert_to_fp16(res)

        print("start compiling")
        self.compile()
@@ -46,7 +46,7 @@ _lib = ctypes.cdll.LoadLibrary(_lib_path)
 _lib.InitLLMPipeline.argtypes = [ctypes.c_int] * 5 + [ctypes.c_char_p] * 5
 _lib.InitLLMPipeline.restype = ctypes.c_int

-_lib.generate_serve.argtypes = [ctypes.c_int] * 5
+_lib.generate_serve.argtypes = [ctypes.c_int] * 4 + [ctypes.c_bool] + [ctypes.c_int]
 _lib.generate_serve.restype = ctypes.c_int

@@ -60,5 +60,6 @@ def InitLLMPipeline(kv_len: int, num_head: int, head_dim: int, num_layers: int,


 def generate_serve(kv_len: int, num_head: int, head_dim: int, num_layers: int,
-                   param_n_output: int):
-    _lib.generate_serve(kv_len, num_head, head_dim, num_layers, param_n_output)
+                   transpose_value_cache: bool, param_n_output: int):
+    _lib.generate_serve(kv_len, num_head, head_dim, num_layers,
+                        transpose_value_cache, param_n_output)
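
Note on the generate_serve change above: the new generate() in convert_pipeline.py starts generate_serve on a background thread and then exchanges token ids with the C++ serve loop over two Windows named pipes, writing each input id (terminated by the eos id) as a 4-byte integer in native byte order and reading generated ids back 4 bytes at a time. A minimal sketch of that framing, assuming the pipe names used in this PR; the helper names are illustrative only:

    import sys

    IN_PIPE_PATH = "\\\\.\\pipe\\llminputpipe"    # written by the Python client
    OUT_PIPE_PATH = "\\\\.\\pipe\\llmoutputpipe"  # written by the C++ serve loop

    def encode_request(token_ids, eos_id):
        # frame each id as 4 bytes in native byte order; the eos id terminates the request
        data = b"".join(int(t).to_bytes(4, sys.byteorder) for t in token_ids)
        return data + int(eos_id).to_bytes(4, sys.byteorder)

    def read_next_token(output_pipe):
        # the serve loop answers with one 4-byte token id per read; an empty read means it has finished
        data = output_pipe.read(4)
        return None if len(data) == 0 else int.from_bytes(data, sys.byteorder)

The extra ctypes.c_bool in the argtypes above mirrors the transpose_value_cache parameter now threaded through from convert_llm.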
@@ -1,266 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import numpy
import warnings
import torch
import sys
import transformers
from typing import List
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
from .pipeline_cpp import InitLLMPipeline, generate_serve
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import GenerationConfig, \
    LogitsProcessorList, StoppingCriteriaList
import threading
from ipex_llm.utils.common import invalidInputError
import os
from transformers import PretrainedConfig


def patch_flash_attn_import(filename: str) -> List[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports


def ignore_argument(kwargs: dict, key: "str"):
    arg = kwargs.pop(key, None)
    if arg is not None:
        warnings.warn(f"argument `{key}={arg}` will be ignored")


def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    # if do_print=True, output timing message
    do_print = kwargs.pop("do_print", False)
    time_start_all, time_t1, idx = time.perf_counter(), None, 0
    new_generate_kwargs = {}
    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
        value = kwargs.pop(var, None)
        if value is not None:
            new_generate_kwargs[var] = value

    if isinstance(inputs[0], torch.Tensor):
        numpy_input = inputs[0].numpy()
    else:
        numpy_input = inputs[0]
    input_length = numpy.size(numpy_input)

    new_tokens = new_generate_kwargs['max_new_tokens']
    invalidInputError(input_length + new_tokens <= self.kv_len + 1,
                      "Input plus output tokens should not exceed max_output_len.")

    # start generate_serve by Thread
    thread = threading.Thread(target=generate_serve,
                              args=(self.kv_len, self.num_head,
                                    self.head_dim, self.num_layers,
                                    new_tokens - 1))
    thread.start()

    in_pipe_path = "\\\\.\\pipe\\llminputpipe"
    out_pipe_path = "\\\\.\\pipe\\llmoutputpipe"

    while True:
        try:
            input_pipe = open(in_pipe_path, "wb")
        except:
            print('Waiting for input pipe')
            time.sleep(1)
        else:
            break

    while True:
        try:
            output_pipe = open(out_pipe_path, "rb")
        except:
            print('Waiting for output pipe')
            time.sleep(1)
        else:
            break

    bdata = b''
    for i in range(0, input_length):
        d = int(numpy_input[i])
        bdata = bdata + d.to_bytes(4, sys.byteorder)

    if "eos_token_id" not in new_generate_kwargs:
        eos = 0xffffffff
    else:
        eos = new_generate_kwargs["eos_token_id"]

    bdata = bdata + eos.to_bytes(4, sys.byteorder)

    time_start = time.perf_counter()

    input_pipe.write(bytearray(bdata))
    input_pipe.flush()

    buffersize = 4
    output_tokens = []
    while True:
        data = output_pipe.read(buffersize)
        if len(data) == 0:
            break
        token = int.from_bytes(data, sys.byteorder)
        idx += 1
        if time_t1 is None:
            time_t1 = time.perf_counter()
        output_tokens.append(torch.tensor([token]))
        if streamer is not None:
            streamer.put(torch.tensor([token]))
        if token == eos:
            break

    output = torch.stack(output_tokens, dim=1)
    output = torch.cat((inputs, output), dim=1)
    if streamer is not None:
        streamer.end()

    thread.join()
    time_end = time.perf_counter()

    if do_print:
        print(f" Start the thread and connect the pipe time: {(time_start - time_start_all):.2f} s")
        print(f" Number of input tokens: {input_length}")
        print(f" Generated tokens: {idx}")
        print(f" First token generation time: {(time_t1 - time_start):.2f} s")
        print(f" Generation average latency: {(time_end - time_t1)*1000 /(idx - 1):.2f} ms, "
              f"({(idx - 1)/(time_end - time_t1):.2f} token/s)")
        print(f" Generation time: {(time_end - time_start):.2f} s\n")

    return output


class NPUModel():
    def __init__(self):
        pass


class _BaseAutoModelClass:
    HF_MODEL = None

    @classmethod
    @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Load a model from a directory or the HF Hub.
        The loaded model will run supported OPs on NPU, then run other OPs on CPU.

        Three new arguments are added to extend Hugging Face's from_pretrained method as follows:
        :param ov_model: boolean value, whether load blob files from specified directory.
                         If it's False, will convert HF model to specified blob format,
                         but which is not supported now. Default to True.
        :param max_output_len: Maximum context length for whole generation, default to 1024.
        :param model_name: Name prefix of the model weight bin file.
        :return: a model instance
        """
        ov_model = kwargs.get("ov_model", True)
        max_output_len = kwargs.pop("max_output_len", 1024)

        invalidInputError(ov_model,
                          "Original HF model is not supported now.")
        invalidInputError(os.path.exists(pretrained_model_name_or_path),
                          "This directory does not exist, please double check it.")

        config_json = os.path.join(pretrained_model_name_or_path, "config.json")
        invalidInputError(os.path.exists(config_json),
                          "config.json is not found in current directory, please double check it.")
        config = PretrainedConfig.from_json_file(config_json)
        model = NPUModel()
        model.kv_len = max_output_len - 1
        model.num_head = config.num_attention_heads
        model.head_dim = config.hidden_size // config.num_attention_heads
        model.num_layers = config.num_hidden_layers
        model.vocab_size = config.vocab_size

        model_weight_dir = os.path.join(pretrained_model_name_or_path, "model_layer_weights")
        model_name = kwargs.get("model_name", "Model")
        first_blob_name = os.path.join(pretrained_model_name_or_path, "first_model.blob")
        last_blob_name = os.path.join(pretrained_model_name_or_path, "last_model.blob")
        rest_blob_name = os.path.join(pretrained_model_name_or_path, "rest_model.blob")

        for path in [model_weight_dir, first_blob_name, last_blob_name, rest_blob_name]:
            invalidInputError(os.path.exists(path),
                              f"{path} is not found in current directory, please double check it.")

        try:
            res = InitLLMPipeline(model.kv_len, model.num_head, model.head_dim, model.num_layers,
                                  model.vocab_size, model_weight_dir, model_name,
                                  first_blob_name, last_blob_name, rest_blob_name)
        except:
            invalidInputError(False,
                              "False to InitLLMPipeline.")
            exit(0)

        # patch generate function
        import types
        model.generate = types.MethodType(generate, model)
        return model


class AutoModelForCausalLM(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForCausalLM


class AutoModel(_BaseAutoModelClass):
    HF_Model = transformers.AutoModel


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForSpeechSeq2Seq


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForSeq2SeqLM


class AutoModelForSequenceClassification(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForSequenceClassification


class AutoModelForMaskedLM(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForMaskedLM


class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForQuestionAnswering


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForNextSentencePrediction


class AutoModelForMultipleChoice(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForMultipleChoice


class AutoModelForTokenClassification(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForTokenClassification