Support hf generate (#12477)

* generate

* style

* update

* remove timing

* style

* style

* combine generate api

* simple in kwargs
Kai Huang 2024-12-04 16:31:09 +08:00 committed by GitHub
parent ef4028ac2d
commit 7ff4533b39
2 changed files with 114 additions and 5 deletions


@@ -14,16 +14,17 @@
# limitations under the License.
from ipex_llm.utils.common.log4Error import invalidInputError
import os
import time
import torch
import importlib
from ipex_llm.transformers.npu_models.linear import QuantizedLinear
import time
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple
from transformers import GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import CausalLMOutputWithPast
from ipex_llm.transformers.utils import module_name_process
from ipex_llm.transformers.npu_models.linear import QuantizedLinear
from ipex_llm.utils.common.log4Error import invalidInputError
def module_optimization(func) -> torch.nn.Module:
@@ -134,6 +135,14 @@ def convert_forward(m, target_m, new_forward):
convert_forward(sub_m, target_m, new_forward)
def general_convert(m, target_m, new_func, func_name="forward"):
if isinstance(m, target_m):
bound_method = new_func.__get__(m, m.__class__)
setattr(m, func_name, bound_method)
for _, sub_m in m.named_children():
general_convert(sub_m, target_m, new_func, func_name)
def optimize_llm(model: torch.nn.Module):
if model.config.model_type == "llama":
from ipex_llm.transformers.npu_models.llama import merge_qkv
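A minimal sketch of how the new general_convert helper can be used: unlike convert_forward, it can rebind any method name on every matching sub-module. DummyModule, new_prepare, and the import path are assumptions for illustration only.

import torch
# assumed import path; general_convert is defined in the module patched above
from ipex_llm.transformers.npu_models.convert import general_convert

class DummyModule(torch.nn.Module):
    def prepare(self, x):
        return x

def new_prepare(self, x):
    # replacement bound as an instance method on each matching sub-module
    return x + 1

root = torch.nn.Sequential(DummyModule(), DummyModule())
general_convert(root, DummyModule, new_prepare, func_name="prepare")
print(root[0].prepare(torch.ones(1)))  # tensor([2.]) once patched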
@@ -298,10 +307,32 @@ def generate(
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
prefix_allowed_tokens_fn: Optional[Callable] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
simple = kwargs.pop("simple", True)
if simple:
return simple_generate(self, inputs=inputs, streamer=streamer, **kwargs)
else:
return self.original_generate(inputs=inputs, generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus, assistant_model=assistant_model,
streamer=streamer, negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
**kwargs)
def simple_generate(
self,
inputs: Optional[torch.Tensor] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
):
# if do_print=True, output timing message
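A rough usage sketch of the combined generate API: once model.generate is patched, the simple kwarg selects the lightweight simple_generate loop (the default), while simple=False falls through to the original HuggingFace generate with its full feature set. The model and tokenizer below are assumed to be already loaded and converted for NPU pipeline mode.

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids

# default path: simple=True routes to simple_generate
output = model.generate(input_ids, max_new_tokens=32)

# opt out to the stock HuggingFace generate, e.g. to enable sampling
output = model.generate(input_ids, max_new_tokens=32,
                        do_sample=True, simple=False)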
@@ -407,10 +438,64 @@ def optimize_llm_single_process(
model.model_ptr = model_ptr
model.save_directory = save_directory
model.vocab_size = model.config.vocab_size
model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
except:
invalidInputError(False,
"False to InitLLMPipeline.")
# patch model forward
from transformers.modeling_utils import PreTrainedModel
general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
general_convert(model, PreTrainedModel, causal_lm_forward)
# patch generate function
import types
model.original_generate = model.generate
model.generate = types.MethodType(generate, model)
return model
def prepare_input_ids(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
if past_key_values is not None: # kvcache
input_ids = input_ids[:, -1]
else: # prefill, reset the model here
from .npu_llm_cpp import reset
reset(self.model_ptr)
model_inputs = {
"input_ids": input_ids
}
return model_inputs
def causal_lm_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
from .npu_llm_cpp import run_prefill_with_logits, run_decode_with_logits
if isinstance(input_ids[0], torch.Tensor):
input_list = input_ids[0].flatten().tolist()
else:
input_list = input_ids[0]
input_length = len(input_list)
if input_length > 1:
logits = run_prefill_with_logits(self.model_ptr, input_list,
self.logits_buffer, self.vocab_size)
else:
logits = run_decode_with_logits(self.model_ptr, input_list[0],
self.logits_buffer, self.vocab_size)
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=1, # just an indicator
hidden_states=None,
attentions=None,
)
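For intuition, a hedged sketch of the loop that HuggingFace generate effectively drives through these two hooks: prepare_inputs_for_generation trims the sequence to the last token once past_key_values is set (the integer 1 returned above serves only as a "cache exists" flag), and causal_lm_forward dispatches to prefill or decode in the native pipeline based on input length. The helper below is illustrative, not library code.

import torch

def greedy_decode(model, input_ids, max_new_tokens=8):
    past = None
    generated = input_ids
    for _ in range(max_new_tokens):
        inputs = model.prepare_inputs_for_generation(generated, past_key_values=past)
        out = model(**inputs)                  # prefill on the first step, decode afterwards
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_id], dim=-1)
        past = out.past_key_values             # the indicator switches later steps to decode
    return generated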


@@ -60,6 +60,14 @@ _lib.llm_sample_token.restype = ctypes.c_int
_lib.reset.argtypes = [ctypes.c_void_p]
_lib.reset.restype = None
_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int),
ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_prefill_with_logits.restype = None
_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int,
ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_decode_with_logits.restype = None
def load_model_from_file(model_dir: str):
return _lib.load_model_from_file(model_dir.encode('utf-8'))
@@ -79,5 +87,21 @@ def run_decode(model_ptr, input_id, vocab_size):
return new_token
def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size):
input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
input_len = len(input_ids)
logits_ptr = logits.data.data_ptr()
logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
_lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr, vocab_size)
return logits
def run_decode_with_logits(model_ptr, input_id, logits, vocab_size):
logits_ptr = logits.data.data_ptr()
logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
_lib.run_decode_with_logits(model_ptr, input_id, logits_ptr, vocab_size)
return logits
def reset(model_ptr):
_lib.reset(model_ptr)
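A short, hedged sketch of driving the new logits-returning entry points directly: the caller preallocates a float32 tensor of shape (1, 1, vocab_size) and the native library writes into it in place, mirroring the model.logits_buffer set up in the first file. The model directory, vocab size, and prompt ids below are placeholders, and the import path is assumed.

import torch
# assumed import path for the wrappers defined above
from ipex_llm.transformers.npu_models.npu_llm_cpp import (
    load_model_from_file, run_prefill_with_logits, run_decode_with_logits, reset)

model_ptr = load_model_from_file("path/to/converted_model_dir")
vocab_size = 32000
logits_buffer = torch.empty(1, 1, vocab_size, dtype=torch.float32)

# prefill over the whole prompt; logits for the last position land in the buffer
prompt_ids = [1, 15043, 3186]
logits = run_prefill_with_logits(model_ptr, prompt_ids, logits_buffer, vocab_size)
next_id = int(logits.argmax(dim=-1))

# subsequent decode steps feed one token id at a time
logits = run_decode_with_logits(model_ptr, next_id, logits_buffer, vocab_size)

# clear the native-side KV cache before starting a fresh prompt
reset(model_ptr)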