diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp
index 06e19af4..b35a5bd2 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp
@@ -34,7 +34,7 @@ int main(int argc, char ** argv) {
     common_params params;
 
     // path to the npu model directory
-    std::string model_dir;
+    char* model_dir = nullptr;
     // prompt to generate text from
     std::string prompt = "AI是什么?";
     // number of tokens to predict
@@ -69,7 +69,7 @@ int main(int argc, char ** argv) {
             break;
         }
     }
-    if (model_dir.empty()) {
+    if (model_dir == nullptr || model_dir[0] == '\0') {
         print_usage(argc, argv);
         return 1;
     }
@@ -86,8 +86,9 @@ int main(int argc, char ** argv) {
     params.model = model_dir;
     params.prompt = prompt;
 
+    void* model = load_model_from_file(params.model);
     npu_model_params model_params;
-    NPUModel* model = load_model_from_file(model_params, params.model);
+    load_config_from_file(model_params, params.model);
 
     tokenizer_params tok_params;
     load_tokenizer(tok_params, params.model);
@@ -101,8 +102,8 @@ int main(int argc, char ** argv) {
     std::vector<int32_t> embd;  // output ids
     auto start = std::chrono::high_resolution_clock::now();
-    float* logits = run_prefill(model, embd_inp);
-    int32_t token = llm_sample_token(logits, true, model_params);
+    float* logits = run_prefill(model, embd_inp.data(), embd_inp.size());
+    int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
     auto end = std::chrono::high_resolution_clock::now();
     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
     printf("\nPrefill %d tokens cost %d ms.\n", embd_inp.size(), duration.count());
 
@@ -112,7 +113,7 @@ int main(int argc, char ** argv) {
     start = std::chrono::high_resolution_clock::now();
     for (int i = 1; i < params.n_predict; i++){
         auto logits = run_decode(model, embd[i-1]);
-        int32_t token = llm_sample_token(logits, true, model_params);
+        int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
         if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
             embd.push_back(token);
             token_nums ++;
@@ -131,4 +132,4 @@ int main(int argc, char ** argv) {
     printf("\nDecode %d tokens cost %d ms (avg %f ms each token).\n", token_nums, duration.count(), (float)duration.count() / token_nums);
 
     return 0;
-}
\ No newline at end of file
+}
" + "Else, program will raise error.", + ) args = parser.parse_args() model_path = args.repo_id_or_model_path @@ -74,6 +80,7 @@ if __name__ == "__main__": transpose_value_cache=not args.disable_transpose_value_cache, mixed_precision=args.mixed_precision, quantization_group_size=args.quantization_group_size, + save_directory=args.save_directory ) else: model = AutoModelForCausalLM.load_low_bit( diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index fc3c0879..aff61122 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -266,15 +266,28 @@ class _BaseAutoModelClass: model.share_memory() if not pipeline: - optimize_llm( - llm, - max_context_len=max_context_len, - max_prompt_len=max_prompt_len, - inter_pp=inter_pp, - intra_pp=intra_pp, - transpose_value_cache=transpose_value_cache, - group_size=quantization_group_size - ) + if model.config.model_type in ["qwen2"]: + from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process + optimize_llm_single_process( + llm, + kv_len=max_context_len, + max_prompt_len=max_prompt_len, + transpose_value_cache=transpose_value_cache, + group_size=quantization_group_size, + qtype=qtype, + save_directory=save_directory, + fuse_layers=fuse_layers + ) + else: + optimize_llm( + llm, + max_context_len=max_context_len, + max_prompt_len=max_prompt_len, + inter_pp=inter_pp, + intra_pp=intra_pp, + transpose_value_cache=transpose_value_cache, + group_size=quantization_group_size + ) else: from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \ import convert_llm diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index a3e949c6..461ec731 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -14,10 +14,16 @@ # limitations under the License. 
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index a3e949c6..461ec731 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -14,10 +14,16 @@
 # limitations under the License.
 
 
+from ipex_llm.utils.common.log4Error import invalidInputError
 import os
 import torch
 import importlib
 from ipex_llm.transformers.npu_models.linear import QuantizedLinear
+import tempfile
+import time
+from typing import Callable, List, Optional
+from transformers import GenerationConfig, \
+    LogitsProcessorList, StoppingCriteriaList
 
 
 def module_optimization(func) -> torch.nn.Module:
@@ -265,3 +271,110 @@ def optimize_llm_post(model: torch.nn.Module):
                                       in_features=model.lm_head.in_features).to("cpu")
         new_linear._parameters['weight'] = paramsLowBit
         model.lm_head = new_linear
+
+
+def generate(
+    self,
+    inputs: Optional[torch.Tensor] = None,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
+    synced_gpus: Optional[bool] = None,
+    assistant_model: Optional["PreTrainedModel"] = None,
+    streamer: Optional["BaseStreamer"] = None,
+    **kwargs,
+):
+    # if do_print=True, output timing message
+    do_print = kwargs.pop("do_print", False)
+    time_t1 = time.perf_counter()
+    new_generate_kwargs = {}
+    for var in ['max_new_tokens', 'attention_mask', 'eos_token_id']:
+        value = kwargs.pop(var, None)
+        if value is not None:
+            new_generate_kwargs[var] = value
+
+    if isinstance(inputs[0], torch.Tensor):
+        input_list = inputs[0].flatten().tolist()
+    else:
+        input_list = inputs[0]
+    input_length = len(input_list)
+
+    new_tokens = new_generate_kwargs['max_new_tokens']
+    invalidInputError(input_length + new_tokens <= self.kv_len + 1,
+                      "Input plus output tokens should not exceed max_context_len.")
+    # TODO: may optimize this part later
+    invalidInputError(new_tokens < 1024,
+                      f"Generated tokens ({new_tokens}) exceed named pipeline limitation.")
+
+    if "eos_token_id" not in new_generate_kwargs:
+        eos = 0xffffffff
+    else:
+        eos = new_generate_kwargs["eos_token_id"]
+    output_tokens = []
+    from .npu_llm_cpp import run_decode, run_prefill, reset
+
+    token = run_prefill(self.model_ptr, input_list, self.vocab_size)
+    idx = 1
+    time_t2 = time.perf_counter()
+    output_tokens.append(torch.tensor([token]))
+    for i in range(new_tokens - 1):
+        if token == eos:
+            break
+        token = run_decode(self.model_ptr, token, self.vocab_size)
+        idx += 1
+        output_tokens.append(torch.tensor([token]))
+    output = torch.stack(output_tokens, dim=1)
+    output = torch.cat((inputs, output), dim=1)
+    time_t3 = time.perf_counter()
+
+    reset(self.model_ptr)
+    self.first_cost = time_t2 - time_t1  # seconds
+    self.rest_cost_mean = (time_t3 - time_t2) / (idx - 1)  # seconds
+    self.encoder_time = 0.0
+
+    if do_print:
+        print(f" Number of input tokens: {input_length}")
+        print(f" Generated tokens: {idx}")
+        print(f" First token generation time: {(time_t2 - time_t1):.2f} s")
+        print(f" Generation average latency: {(time_t3 - time_t2) * 1000 / (idx - 1):.2f} ms, "
+              f"({(idx - 1) / (time_t3 - time_t2):.2f} token/s)")
+        print(f" Generation time: {(time_t3 - time_t1):.2f} s\n")
+
+    return output
+
+
+def optimize_llm_single_process(
+    model: torch.nn.Module,
+    kv_len: int,
+    max_prompt_len: int,
+    transpose_value_cache: bool,
+    group_size: int,
+    qtype: str,
+    save_directory: str,
+    fuse_layers: int=None
+):
+    from ipex_llm.transformers.npu_pipeline_model.convert_pipeline import convert_llm
+    from .npu_llm_cpp import load_model_from_file
+    convert_llm(model,
+                kv_len=kv_len,
+                max_prompt_len=max_prompt_len,
+                transpose_value_cache=transpose_value_cache,
+                group_size=group_size,
+                qtype=qtype,
+                convert_model=True,
+                save_directory=save_directory,
+                fuse_layers=fuse_layers)
+    try:
+        model_ptr = load_model_from_file(save_directory)
+        model.kv_len = kv_len
+        model.model_ptr = model_ptr
+        model.vocab_size = model.config.vocab_size
+    except:
+        invalidInputError(False,
+                          "Failed to InitLLMPipeline.")
+    # patch generate function
+    import types
+    model.generate = types.MethodType(generate, model)
+    return model
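Usage sketch (illustrative, not part of the patch): once optimize_llm_single_process has patched the model, the generate() above runs prefill and token-by-token decode through the native npu_llm library and returns the prompt ids concatenated with the generated ids. The tokenizer and checkpoint id are assumptions; do_print and max_new_tokens are the knobs handled in the code above.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct",  # assumed checkpoint
                                          trust_remote_code=True)
input_ids = tokenizer.encode("AI是什么?", return_tensors="pt")

# `model` is the object returned by optimize_llm_single_process above.
output = model.generate(input_ids, max_new_tokens=32, do_print=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))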
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py b/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
new file mode 100644
index 00000000..9507a753
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
@@ -0,0 +1,83 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import ctypes
+import pathlib
+from ipex_llm.utils.common import invalidInputError
+
+
+def get_shared_lib_info(lib_base_name: str):
+    # Determine the file extension based on the platform
+    if sys.platform.startswith("linux") or sys.platform == "darwin":
+        lib_ext = ".so"
+    elif sys.platform == "win32":
+        lib_ext = ".dll"
+    else:
+        invalidInputError(False, "Unsupported platform.")
+
+    # Construct the paths to the possible shared library names
+    import importlib
+    module = importlib.import_module("bigdl-core-npu")
+    _base_path = pathlib.Path(module.__file__).parent.resolve()
+
+    lib_path = os.path.join(_base_path, lib_base_name + lib_ext)
+
+    return _base_path, lib_path
+
+
+_, _lib_path = get_shared_lib_info("npu_llm")
+
+# Load the library
+_lib = ctypes.cdll.LoadLibrary(_lib_path)
+
+_lib.load_model_from_file.argtypes = [ctypes.c_char_p]
+_lib.load_model_from_file.restype = ctypes.c_void_p
+
+_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int]
+_lib.run_prefill.restype = ctypes.POINTER(ctypes.c_float)
+
+_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int]
+_lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
+
+_lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
+_lib.llm_sample_token.restype = ctypes.c_int
+
+_lib.reset.argtypes = [ctypes.c_void_p]
+_lib.reset.restype = None
+
+
+def load_model_from_file(model_dir: str):
+    return _lib.load_model_from_file(model_dir.encode('utf-8'))
+
+
+def run_prefill(model_ptr, input_ids, vocab_size):
+    input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
+    input_len = len(input_ids)
+    plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
+    new_token = _lib.llm_sample_token(plogits, True, vocab_size)
+    return new_token
+
+
+def run_decode(model_ptr, input_id, vocab_size):
+    plogits = _lib.run_decode(model_ptr, input_id)
+    new_token = _lib.llm_sample_token(plogits, True, vocab_size)
+    return new_token
+
+
+def reset(model_ptr):
+    _lib.reset(model_ptr)
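Usage sketch (illustrative, not part of the patch): the ctypes helpers above can also be driven directly, which makes the prefill/decode contract explicit. The folder, vocabulary size, and token ids below are made up; in practice the patched generate() pulls them from the converted model and its config.

from ipex_llm.transformers.npu_models.npu_llm_cpp import (
    load_model_from_file, run_prefill, run_decode, reset,
)

save_directory = "./qwen2-npu-converted"  # assumed output of convert_llm
vocab_size = 151936                       # illustrative value; read from config.json in practice
prompt_ids = [151644, 872, 198]           # made-up token ids

model_ptr = load_model_from_file(save_directory)
token = run_prefill(model_ptr, prompt_ids, vocab_size)  # first sampled token
generated = [token]
for _ in range(31):
    token = run_decode(model_ptr, token, vocab_size)    # one token in, next token out
    generated.append(token)
reset(model_ptr)                                        # clear the KV cache before the next prompt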
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 1c736623..c299adff 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -426,9 +426,11 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                            group_size: int,
                            save_directory: str=None,
                            fuse_layers: int=None):
-    os.mkdir(save_directory)
+    if not os.path.exists(save_directory):
+        os.mkdir(save_directory)
     weight_dir = os.path.join(save_directory, "model_weights")
-    os.mkdir(weight_dir)
+    if not os.path.exists(weight_dir):
+        os.mkdir(weight_dir)
     layernorm_const = os.environ.get("IPEX_LLM_NPU_LAYERNORM_CONST", "1") == "1"
 
     if model.config.model_type == "qwen2":
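Design note (illustrative, not part of the patch): guarding os.mkdir with os.path.exists makes convert_llm_for_deploy safe to re-run against an existing save_directory. A sketch of an equivalent, slightly more compact alternative would be os.makedirs with exist_ok, which also creates missing parent directories:

import os

save_directory = "./qwen2-npu-converted"  # placeholder path
# Creates save_directory and model_weights in one call; no error if they already exist.
os.makedirs(os.path.join(save_directory, "model_weights"), exist_ok=True)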