diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index 2abe3141..55a7d091 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -14,16 +14,17 @@
 # limitations under the License.
 
 
-from ipex_llm.utils.common.log4Error import invalidInputError
 import os
+import time
 import torch
 import importlib
-from ipex_llm.transformers.npu_models.linear import QuantizedLinear
-import time
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union, Tuple
 from transformers import GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
+from transformers.modeling_outputs import CausalLMOutputWithPast
 from ipex_llm.transformers.utils import module_name_process
+from ipex_llm.transformers.npu_models.linear import QuantizedLinear
+from ipex_llm.utils.common.log4Error import invalidInputError
 
 
 def module_optimization(func) -> torch.nn.Module:
@@ -134,6 +135,14 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)
 
 
+def general_convert(m, target_m, new_func, func_name="forward"):
+    if isinstance(m, target_m):
+        bound_method = new_func.__get__(m, m.__class__)
+        setattr(m, func_name, bound_method)
+    for _, sub_m in m.named_children():
+        general_convert(sub_m, target_m, new_func, func_name)
+
+
 def optimize_llm(model: torch.nn.Module):
     if model.config.model_type == "llama":
         from ipex_llm.transformers.npu_models.llama import merge_qkv
@@ -298,10 +307,32 @@ def generate(
     generation_config: Optional[GenerationConfig] = None,
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
-    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
+    prefix_allowed_tokens_fn: Optional[Callable] = None,
     synced_gpus: Optional[bool] = None,
     assistant_model: Optional["PreTrainedModel"] = None,
     streamer: Optional["BaseStreamer"] = None,
+    negative_prompt_ids: Optional[torch.Tensor] = None,
+    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    simple = kwargs.pop("simple", True)
+    if simple:
+        return simple_generate(self, inputs=inputs, streamer=streamer, **kwargs)
+    else:
+        return self.original_generate(inputs=inputs, generation_config=generation_config,
+                                      logits_processor=logits_processor,
+                                      stopping_criteria=stopping_criteria,
+                                      prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                      synced_gpus=synced_gpus, assistant_model=assistant_model,
+                                      streamer=streamer, negative_prompt_ids=negative_prompt_ids,
+                                      negative_prompt_attention_mask=negative_prompt_attention_mask,
+                                      **kwargs)
+
+
+def simple_generate(
+    self,
+    inputs: Optional[torch.Tensor] = None,
+    streamer: Optional["BaseStreamer"] = None,
     **kwargs,
 ):
     # if do_print=True, output timing message
@@ -407,10 +438,64 @@ def optimize_llm_single_process(
         model.model_ptr = model_ptr
         model.save_directory = save_directory
         model.vocab_size = model.config.vocab_size
+        model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
     except:
         invalidInputError(False,
                           "False to InitLLMPipeline.")
+    # patch model forward
+    from transformers.modeling_utils import PreTrainedModel
+    general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
+    general_convert(model, PreTrainedModel, causal_lm_forward)
     # patch generate function
     import types
+    model.original_generate = model.generate
     model.generate = types.MethodType(generate, model)
     return model
+
+
+def prepare_input_ids(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
+    if past_key_values is not None:  # kvcache
+        input_ids = input_ids[:, -1]
+    else:  # prefill, reset the model here
+        from .npu_llm_cpp import reset
+        reset(self.model_ptr)
+    model_inputs = {
+        "input_ids": input_ids
+    }
+    return model_inputs
+
+
+def causal_lm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    from .npu_llm_cpp import run_prefill_with_logits, run_decode_with_logits
+    if isinstance(input_ids[0], torch.Tensor):
+        input_list = input_ids[0].flatten().tolist()
+    else:
+        input_list = input_ids[0]
+    input_length = len(input_list)
+    if input_length > 1:
+        logits = run_prefill_with_logits(self.model_ptr, input_list,
+                                         self.logits_buffer, self.vocab_size)
+    else:
+        logits = run_decode_with_logits(self.model_ptr, input_list[0],
+                                        self.logits_buffer, self.vocab_size)
+
+    return CausalLMOutputWithPast(
+        loss=None,
+        logits=logits,
+        past_key_values=1,  # just an indicator
+        hidden_states=None,
+        attentions=None,
+    )
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py b/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
index 9507a753..8dca443f 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
@@ -60,6 +60,14 @@ _lib.llm_sample_token.restype = ctypes.c_int
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None
 
+_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int),
+                                         ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
+_lib.run_prefill_with_logits.restype = None
+
+_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int,
+                                        ctypes.POINTER(ctypes.c_float), ctypes.c_int]
+_lib.run_decode_with_logits.restype = None
+
 
 def load_model_from_file(model_dir: str):
     return _lib.load_model_from_file(model_dir.encode('utf-8'))
@@ -79,5 +87,21 @@ def run_decode(model_ptr, input_id, vocab_size):
     return new_token
 
 
+def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size):
+    input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
+    input_len = len(input_ids)
+    logits_ptr = logits.data.data_ptr()
+    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
+    _lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr, vocab_size)
+    return logits
+
+
+def run_decode_with_logits(model_ptr, input_id, logits, vocab_size):
+    logits_ptr = logits.data.data_ptr()
+    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
+    _lib.run_decode_with_logits(model_ptr, input_id, logits_ptr, vocab_size)
+    return logits
+
+
 def reset(model_ptr):
     _lib.reset(model_ptr)
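
Usage note: with this change, `generate` on the NPU C++ pipeline dispatches on a new `simple` kwarg. `simple=True` (the default) keeps the existing `simple_generate` path, while `simple=False` falls back to the stock transformers loop via `original_generate`; each step then runs through the patched `causal_lm_forward`, which fills `model.logits_buffer` via `run_prefill_with_logits` / `run_decode_with_logits`. Below is a minimal sketch of how a caller might exercise both paths; it assumes `model` has already gone through `optimize_llm_single_process` and that `tokenizer` is the matching Hugging Face tokenizer, and the prompt and generation arguments are illustrative placeholders only.

# Sketch only: `model` is assumed to be an NPU C++ pipeline model whose
# generate/forward have been patched by optimize_llm_single_process;
# `tokenizer` is the corresponding Hugging Face tokenizer.
prompt = "What is AI?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Default path: simple=True routes to simple_generate(), which drives the
# C++ pipeline's own decoding loop.
output = model.generate(input_ids, max_new_tokens=32)

# Path added by this diff: simple=False falls back to original_generate(),
# i.e. the standard transformers generation loop; each step calls the patched
# causal_lm_forward, which obtains logits through run_prefill_with_logits /
# run_decode_with_logits, so sampling options can be applied on top.
output = model.generate(input_ids, max_new_tokens=32, do_sample=True,
                        top_p=0.9, simple=False)

print(tokenizer.decode(output[0], skip_special_tokens=True))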