Support hf generate (#12477)
* generate
* style
* update
* remove timing
* style
* style
* combine generate api
* simple in kwargs
parent ef4028ac2d, commit 7ff4533b39
2 changed files with 114 additions and 5 deletions
@@ -14,16 +14,17 @@
# limitations under the License.

from ipex_llm.utils.common.log4Error import invalidInputError
import os
import time
import torch
import importlib
from ipex_llm.transformers.npu_models.linear import QuantizedLinear
import time
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union, Tuple
from transformers import GenerationConfig, \
    LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import CausalLMOutputWithPast
from ipex_llm.transformers.utils import module_name_process
from ipex_llm.transformers.npu_models.linear import QuantizedLinear
from ipex_llm.utils.common.log4Error import invalidInputError


def module_optimization(func) -> torch.nn.Module:
@@ -134,6 +135,14 @@ def convert_forward(m, target_m, new_forward):
        convert_forward(sub_m, target_m, new_forward)


def general_convert(m, target_m, new_func, func_name="forward"):
    if isinstance(m, target_m):
        bound_method = new_func.__get__(m, m.__class__)
        setattr(m, func_name, bound_method)
    for _, sub_m in m.named_children():
        general_convert(sub_m, target_m, new_func, func_name)


def optimize_llm(model: torch.nn.Module):
    if model.config.model_type == "llama":
        from ipex_llm.transformers.npu_models.llama import merge_qkv
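For orientation, general_convert generalizes the existing convert_forward helper: it rebinds an arbitrary attribute (defaulting to "forward") on every submodule matching target_m, which is what later lets this commit patch prepare_inputs_for_generation as well. A minimal usage sketch with general_convert as defined above; the replacement function and the "gpt2" checkpoint are illustrative only, not part of this commit:

from transformers import AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel

def my_prepare_inputs(self, input_ids, **kwargs):
    # illustrative replacement that forwards only the token ids
    return {"input_ids": input_ids}

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any HF causal LM works for this demo
# bind my_prepare_inputs as prepare_inputs_for_generation on the model and all matching submodules
general_convert(model, PreTrainedModel, my_prepare_inputs, "prepare_inputs_for_generation")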
@@ -298,10 +307,32 @@ def generate(
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
    prefix_allowed_tokens_fn: Optional[Callable] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    simple = kwargs.pop("simple", True)
    if simple:
        return simple_generate(self, inputs=inputs, streamer=streamer, **kwargs)
    else:
        return self.original_generate(inputs=inputs, generation_config=generation_config,
                                      logits_processor=logits_processor,
                                      stopping_criteria=stopping_criteria,
                                      prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                                      synced_gpus=synced_gpus, assistant_model=assistant_model,
                                      streamer=streamer, negative_prompt_ids=negative_prompt_ids,
                                      negative_prompt_attention_mask=negative_prompt_attention_mask,
                                      **kwargs)


def simple_generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    # if do_print=True, output timing message
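A hedged usage sketch of the combined generate API above: once the model has been patched by optimize_llm_single_process (next hunk), generate takes the lightweight simple_generate path by default and only falls back to the stock Hugging Face implementation when simple=False is passed. The loader entry point, checkpoint, and generation kwargs below are assumptions for illustration, not part of this commit:

from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed NPU loader entry point

model_path = "meta-llama/Llama-2-7b-chat-hf"                      # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit="sym_int4",
                                             trust_remote_code=True)

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
out = model.generate(input_ids, max_new_tokens=32)                # default: simple pipeline
out = model.generate(input_ids, max_new_tokens=32, simple=False,  # opt out to HF generate
                     do_sample=True, top_p=0.9)
print(tokenizer.decode(out[0], skip_special_tokens=True))

Popping "simple" out of **kwargs keeps the wrapper signature-compatible with PreTrainedModel.generate, so existing call sites continue to work unchanged.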
@@ -407,10 +438,64 @@ def optimize_llm_single_process(
        model.model_ptr = model_ptr
        model.save_directory = save_directory
        model.vocab_size = model.config.vocab_size
        model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
    except:
        invalidInputError(False,
                          "Failed to InitLLMPipeline.")
    # patch model forward
    from transformers.modeling_utils import PreTrainedModel
    general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
    general_convert(model, PreTrainedModel, causal_lm_forward)
    # patch generate function
    import types
    model.original_generate = model.generate
    model.generate = types.MethodType(generate, model)
    return model


def prepare_input_ids(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
    if past_key_values is not None:  # kvcache
        input_ids = input_ids[:, -1]
    else:  # prefill, reset the model here
        from .npu_llm_cpp import reset
        reset(self.model_ptr)
    model_inputs = {
        "input_ids": input_ids
    }
    return model_inputs


def causal_lm_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    from .npu_llm_cpp import run_prefill_with_logits, run_decode_with_logits
    if isinstance(input_ids[0], torch.Tensor):
        input_list = input_ids[0].flatten().tolist()
    else:
        input_list = input_ids[0]
    input_length = len(input_list)
    if input_length > 1:
        logits = run_prefill_with_logits(self.model_ptr, input_list,
                                         self.logits_buffer, self.vocab_size)
    else:
        logits = run_decode_with_logits(self.model_ptr, input_list[0],
                                        self.logits_buffer, self.vocab_size)

    return CausalLMOutputWithPast(
        loss=None,
        logits=logits,
        past_key_values=1,  # just an indicator
        hidden_states=None,
        attentions=None,
    )
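Putting the two patches together: on the first call HF's generation loop has no cache, so prepare_input_ids forwards the full prompt and resets the native pipeline; every later call sees the truthy past_key_values=1 indicator and slices out only the newest token. A hedged trace of two steps, assuming model has already been through optimize_llm_single_process; the helper itself is illustrative and not part of the commit:

import torch

def trace_two_steps(model, prompt_ids):
    # step 1 (prefill): no cache yet, so the whole prompt is passed and the pipeline is reset
    inputs = model.prepare_inputs_for_generation(prompt_ids, past_key_values=None)
    out = model(**inputs)                     # causal_lm_forward -> run_prefill_with_logits
    next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    ids = torch.cat([prompt_ids, next_id], dim=-1)

    # step 2 (decode): past_key_values is the truthy "1" indicator, so only the newest token is sent
    inputs = model.prepare_inputs_for_generation(ids, past_key_values=out.past_key_values)
    out = model(**inputs)                     # causal_lm_forward -> run_decode_with_logits
    return out.logits

The "1" stands in for a real cache presumably because the KV cache lives on the native side of the pipeline; the Python layer only needs a truthy marker to flip from prefill to decode.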
@@ -60,6 +60,14 @@ _lib.llm_sample_token.restype = ctypes.c_int
_lib.reset.argtypes = [ctypes.c_void_p]
_lib.reset.restype = None

_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int),
                                         ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_prefill_with_logits.restype = None

_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int,
                                        ctypes.POINTER(ctypes.c_float), ctypes.c_int]
_lib.run_decode_with_logits.restype = None


def load_model_from_file(model_dir: str):
    return _lib.load_model_from_file(model_dir.encode('utf-8'))
@@ -79,5 +87,21 @@ def run_decode(model_ptr, input_id, vocab_size):
    return new_token


def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size):
    input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
    input_len = len(input_ids)
    logits_ptr = logits.data.data_ptr()
    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
    _lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr, vocab_size)
    return logits


def run_decode_with_logits(model_ptr, input_id, logits, vocab_size):
    logits_ptr = logits.data.data_ptr()
    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
    _lib.run_decode_with_logits(model_ptr, input_id, logits_ptr, vocab_size)
    return logits


def reset(model_ptr):
    _lib.reset(model_ptr)
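These wrappers fill a caller-owned buffer instead of allocating per step: optimize_llm_single_process pre-allocates model.logits_buffer = torch.empty(1, 1, vocab_size) once and its raw float pointer is handed to the C library. A minimal sketch of that round trip using the functions defined in this file; the model directory, token ids, and vocab size are illustrative, and a real converted NPU model is required for it to actually run:

import torch

model_ptr = load_model_from_file("/path/to/converted_model")    # illustrative directory
vocab_size = 32000                                               # illustrative vocab size
logits = torch.empty(1, 1, vocab_size, dtype=torch.float32)      # one buffer, reused every step

# prefill: the C side writes vocab_size floats for the last prompt position into `logits`
prompt = [1, 15043, 3186]                                        # illustrative token ids
out = run_prefill_with_logits(model_ptr, prompt, logits, vocab_size)

# decode: only the newest token id crosses the boundary; the same buffer is refilled
next_id = int(out.argmax())
out = run_decode_with_logits(model_ptr, next_id, logits, vocab_size)

assert out.data_ptr() == logits.data_ptr()                       # same storage, no per-step allocation

vocab_size is passed explicitly because the buffer crosses the ctypes boundary as a bare float*, so the C side otherwise has no way to know how many floats to write.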