Support hf generate (#12477)

* generate
* style
* update
* remove timing
* style
* style
* combine generate api
* simple in kwargs
parent ef4028ac2d
commit 7ff4533b39

2 changed files with 114 additions and 5 deletions
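What the change amounts to for callers, as a hedged sketch: the patched generate pops a simple kwarg (default True) and routes to the new simple_generate pipeline loop, while simple=False falls through to the stock transformers generate. The model id, tokenizer, and loading call below are illustrative assumptions, not taken from this diff.

# Sketch only: assumes ipex_llm's NPU AutoModelForCausalLM wrapper and a
# hypothetical model id; loading kwargs may differ in your ipex-llm version.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # hypothetical
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids

out_simple = model.generate(input_ids, max_new_tokens=32)            # default: simple_generate
out_hf = model.generate(input_ids, max_new_tokens=32, simple=False)  # original hf generate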
				
			
@@ -14,16 +14,17 @@
 # limitations under the License.


-from ipex_llm.utils.common.log4Error import invalidInputError
 import os
+import time
 import torch
 import importlib
-from ipex_llm.transformers.npu_models.linear import QuantizedLinear
-import time
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union, Tuple
 from transformers import GenerationConfig, \
     LogitsProcessorList, StoppingCriteriaList
+from transformers.modeling_outputs import CausalLMOutputWithPast
 from ipex_llm.transformers.utils import module_name_process
+from ipex_llm.transformers.npu_models.linear import QuantizedLinear
+from ipex_llm.utils.common.log4Error import invalidInputError


 def module_optimization(func) -> torch.nn.Module:
@@ -134,6 +135,14 @@ def convert_forward(m, target_m, new_forward):
         convert_forward(sub_m, target_m, new_forward)


+def general_convert(m, target_m, new_func, func_name="forward"):
+    if isinstance(m, target_m):
+        bound_method = new_func.__get__(m, m.__class__)
+        setattr(m, func_name, bound_method)
+    for _, sub_m in m.named_children():
+        general_convert(sub_m, target_m, new_func, func_name)
+
+
 def optimize_llm(model: torch.nn.Module):
     if model.config.model_type == "llama":
         from ipex_llm.transformers.npu_models.llama import merge_qkv
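The new general_convert helper generalizes convert_forward: it rebinds any named method, not just forward, on every submodule matching target_m. A minimal self-contained sketch of the rebinding mechanics, with toy classes that are not part of the diff:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x

def general_convert(m, target_m, new_func, func_name="forward"):
    # Same body as the hunk above: bind new_func to each matching instance.
    if isinstance(m, target_m):
        bound_method = new_func.__get__(m, m.__class__)
        setattr(m, func_name, bound_method)
    for _, sub_m in m.named_children():
        general_convert(sub_m, target_m, new_func, func_name)

def doubled(self, x):
    return 2 * x

root = torch.nn.Sequential(Toy(), Toy())
general_convert(root, Toy, doubled)   # patch forward on every Toy submodule
print(root(torch.tensor([1.0])))      # tensor([4.]) after two doublings

The same mechanism later installs prepare_input_ids under the name "prepare_inputs_for_generation", which plain convert_forward could not do.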
@@ -298,10 +307,32 @@ def generate(
     generation_config: Optional[GenerationConfig] = None,
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
-    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
+    prefix_allowed_tokens_fn: Optional[Callable] = None,
     synced_gpus: Optional[bool] = None,
     assistant_model: Optional["PreTrainedModel"] = None,
     streamer: Optional["BaseStreamer"] = None,
     negative_prompt_ids: Optional[torch.Tensor] = None,
     negative_prompt_attention_mask: Optional[torch.Tensor] = None,
     **kwargs,
 ):
+    simple = kwargs.pop("simple", True)
+    if simple:
+        return simple_generate(self, inputs=inputs, streamer=streamer, **kwargs)
+    else:
+        return self.original_generate(inputs=inputs, generation_config=generation_config,
+                                      logits_processor=logits_processor,
+                                      stopping_criteria=stopping_criteria,
+                                      prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                      synced_gpus=synced_gpus, assistant_model=assistant_model,
+                                      streamer=streamer, negative_prompt_ids=negative_prompt_ids,
+                                      negative_prompt_attention_mask=negative_prompt_attention_mask,
+                                      **kwargs)
+
+
+def simple_generate(
+    self,
+    inputs: Optional[torch.Tensor] = None,
+    streamer: Optional["BaseStreamer"] = None,
+    **kwargs,
+):
+    # if do_print=True, output timing message
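Both dispatch paths accept a streamer, so incremental decoding still works under the simple path. A hedged usage sketch with transformers' TextStreamer, reusing the model and tokenizer assumed in the earlier sketch:

from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True)  # prints tokens as they arrive
model.generate(input_ids, max_new_tokens=32, streamer=streamer)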
@@ -407,10 +438,64 @@ def optimize_llm_single_process(
         model.model_ptr = model_ptr
         model.save_directory = save_directory
         model.vocab_size = model.config.vocab_size
+        model.logits_buffer = torch.empty(1, 1, model.vocab_size, dtype=torch.float32)
     except:
         invalidInputError(False,
                           "False to InitLLMPipeline.")
+    # patch model forward
+    from transformers.modeling_utils import PreTrainedModel
+    general_convert(model, PreTrainedModel, prepare_input_ids, "prepare_inputs_for_generation")
+    general_convert(model, PreTrainedModel, causal_lm_forward)
     # patch generate function
     import types
+    model.original_generate = model.generate
     model.generate = types.MethodType(generate, model)
     return model
+
+
+def prepare_input_ids(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
+    if past_key_values is not None:  # kvcache
+        input_ids = input_ids[:, -1]
+    else:  # prefill, reset the model here
+        from .npu_llm_cpp import reset
+        reset(self.model_ptr)
+    model_inputs = {
+        "input_ids": input_ids
+    }
+    return model_inputs
+
+
+def causal_lm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    from .npu_llm_cpp import run_prefill_with_logits, run_decode_with_logits
+    if isinstance(input_ids[0], torch.Tensor):
+        input_list = input_ids[0].flatten().tolist()
+    else:
+        input_list = input_ids[0]
+    input_length = len(input_list)
+    if input_length > 1:
+        logits = run_prefill_with_logits(self.model_ptr, input_list,
+                                         self.logits_buffer, self.vocab_size)
+    else:
+        logits = run_decode_with_logits(self.model_ptr, input_list[0],
+                                        self.logits_buffer, self.vocab_size)
+
+    return CausalLMOutputWithPast(
+        loss=None,
+        logits=logits,
+        past_key_values=1,  # just an indicator
+        hidden_states=None,
+        attentions=None,
+    )
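The patched prepare_inputs_for_generation keeps only the newest token once a KV-cache marker is present, and resets the native pipeline at prefill, so causal_lm_forward can use sequence length alone to pick prefill versus decode. A toy illustration of that contract (no NPU calls; token values are arbitrary):

import torch

def prepare_input_ids(input_ids, past_key_values=None):
    # Mirrors the diff: full prompt on prefill, last token only on decode.
    if past_key_values is not None:
        input_ids = input_ids[:, -1]
    return {"input_ids": input_ids}

prompt = torch.tensor([[1, 15043, 3186]])           # 3-token prompt
print(prepare_input_ids(prompt))                     # prefill: all 3 tokens
print(prepare_input_ids(prompt, past_key_values=1))  # decode: tensor([3186])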
npu_llm_cpp.py:
@@ -60,6 +60,14 @@ _lib.llm_sample_token.restype = ctypes.c_int
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None

+_lib.run_prefill_with_logits.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int),
+                                         ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_int]
+_lib.run_prefill_with_logits.restype = None
+
+_lib.run_decode_with_logits.argtypes = [ctypes.c_void_p, ctypes.c_int,
+                                        ctypes.POINTER(ctypes.c_float), ctypes.c_int]
+_lib.run_decode_with_logits.restype = None
+

 def load_model_from_file(model_dir: str):
     return _lib.load_model_from_file(model_dir.encode('utf-8'))
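These argtypes/restype declarations describe the native entry points the wrappers further down call. The same zero-copy pattern, handing a torch-owned buffer to C code by raw pointer, can be sketched with libc's memset (POSIX-style library lookup assumed; this does not touch the NPU library):

import ctypes, ctypes.util
import torch

libc = ctypes.CDLL(ctypes.util.find_library("c"))  # may need a path on Windows
libc.memset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
libc.memset.restype = ctypes.c_void_p

buf = torch.empty(4, dtype=torch.float32)
libc.memset(buf.data_ptr(), 0, buf.numel() * buf.element_size())
print(buf)  # tensor([0., 0., 0., 0.]): C code filled torch's memory in place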
@@ -79,5 +87,21 @@ def run_decode(model_ptr, input_id, vocab_size):
     return new_token


+def run_prefill_with_logits(model_ptr, input_ids, logits, vocab_size):
+    input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
+    input_len = len(input_ids)
+    logits_ptr = logits.data.data_ptr()
+    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
+    _lib.run_prefill_with_logits(model_ptr, input_ptr, input_len, logits_ptr, vocab_size)
+    return logits
+
+
+def run_decode_with_logits(model_ptr, input_id, logits, vocab_size):
+    logits_ptr = logits.data.data_ptr()
+    logits_ptr = ctypes.cast(logits_ptr, ctypes.POINTER(ctypes.c_float))
+    _lib.run_decode_with_logits(model_ptr, input_id, logits_ptr, vocab_size)
+    return logits
+
+
 def reset(model_ptr):
     _lib.reset(model_ptr)
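Putting the two files together: optimize_llm_single_process pre-allocates model.logits_buffer once, and every prefill/decode call writes logits into that same tensor. A hedged sketch of the round trip; the vocab size and token ids are placeholders:

import torch

vocab_size = 32000  # hypothetical; in practice read from model.config.vocab_size
logits_buffer = torch.empty(1, 1, vocab_size, dtype=torch.float32)

# With a real compiled pipeline the flow would be (names from the diff):
#   model_ptr = load_model_from_file(save_directory)       # opaque C handle
#   logits = run_prefill_with_logits(model_ptr, [1, 15043, 3186],
#                                    logits_buffer, vocab_size)
#   next_token = int(logits.argmax(dim=-1))                 # greedy pick
#   logits = run_decode_with_logits(model_ptr, next_token,
#                                   logits_buffer, vocab_size)
# Every call rewrites the same pre-allocated buffer, so no per-step allocation.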