Support performance mode of GLM4 model (#12401)
* Initial support of prepare generation args for transformers 4.45
* Small fix to chatglm4 model optimization
* Small fix
* Fix glm4 position id
* Fix glm4 error
* Small change in condition & fix based on comments
* Style fixes

---------

Co-authored-by: cyita <yitastudy@gmail.com>
parent d2c821d458
commit a69395f31f

3 changed files with 177 additions and 16 deletions
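For context, a minimal usage sketch of the performance mode this commit touches. The `IPEX_LLM_PERFORMANCE_MODE` environment variable is the one checked in the diff below; the model id and the `ipex_llm` loading call follow the usual ipex-llm examples and are assumptions, not part of this commit.

    import os
    os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"   # env var read in the diff below; set before generate()

    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM   # assumed ipex-llm loading API

    model_path = "THUDM/glm-4-9b-chat"                       # example GLM4 checkpoint (assumption)
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    inputs = tokenizer("What is AI?", return_tensors="pt")
    output = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))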
				
			
@@ -27,9 +27,11 @@ import time
 import copy
 import random
 import logging
+import transformers
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
-    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
+    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
+    _prepare_generate_args_4_45
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.utils import get_xpu_device_type
@@ -278,16 +280,21 @@ def lookup_generate(self,
                     streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
-    input_ids, generation_config, logits_processor, stopping_criteria, \
-        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
-                                              **sampling_kwargs)
+    from packaging import version
+    trans_version = transformers.__version__
+
+    if version.parse(trans_version) >= version.parse("4.45.0"):
+        input_ids, generation_config, logits_processor, stopping_criteria, \
+            model_kwargs = _prepare_generate_args_4_45(self, inputs, generation_config,
+                                                       streamer, **sampling_kwargs)
+    else:
+        input_ids, generation_config, logits_processor, stopping_criteria, \
+            model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+                                                  streamer, **sampling_kwargs)
 
     invalidInputError(input_ids.shape[0] == 1,
                       "Prompt lookup is currently not supported with batch inference.")
 
-    if streamer is not None:
-        streamer.put(input_ids.cpu())
-
     device_name = get_xpu_device_type(input_ids)
 
     candidates_generator = PromptLookupCandidateGenerator(
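The dispatch above exists because the generation utilities changed around transformers 4.45; the new helper added later in this diff relies on 4.45-era internals such as `_prepare_special_tokens` and the tensor-valued `generation_config._pad_token_tensor`. An illustrative way to see which path the patched `lookup_generate` would take in a given environment:

    import transformers
    from packaging import version

    # Mirrors the version gate added above (illustrative check only).
    if version.parse(transformers.__version__) >= version.parse("4.45.0"):
        print("transformers >= 4.45.0: _prepare_generate_args_4_45 path")
    else:
        print("transformers < 4.45.0: legacy _prepare_generate_args path")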
				
			
@@ -76,9 +76,15 @@ def chatglm4_model_forward(
     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or\
                 (past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(input_ids,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
+            if self.config.hidden_size == 4096:
+                # glm4-9b
+                full_attention_mask = self.get_masks(input_ids,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
+            else:
+                full_attention_mask = self.get_masks(inputs_embeds,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
 
     # ipex-llm changes begin
     # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
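The `hidden_size == 4096` test distinguishes glm4-9b from the other GLM4 variants, whose masks are built from `inputs_embeds` instead of `input_ids`. A small, illustrative check of which branch a given checkpoint would take (the model id is an example, not part of this commit):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
    branch = "input_ids mask (glm4-9b)" if cfg.hidden_size == 4096 else "inputs_embeds mask"
    print(cfg.hidden_size, "->", branch)   # expected: 4096 -> input_ids mask (glm4-9b)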
				
			
@@ -14,16 +14,34 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
-# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
-# /utils.py
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py and
+# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/utils.py
+# which are licensed under Apache License 2.0:
+#
+# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors
+# and The HuggingFace Inc. team.
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# and https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
 
 import torch
 import time
 import os
 import copy
 import logging
 import inspect
 import transformers
 from packaging import version
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
				
			
@@ -493,14 +511,17 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
                 for k, v in past_key_values
             ]
         elif self.config.model_type == "chatglm":
-            if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+            if isinstance(self.config.eos_token_id, list) and \
+                    not hasattr(self.transformer, "vision") and \
+                    self.config.num_layers in [28, 40]:
+                # glm4 models
                 past_key_values = [
                     (k[:, :, :-(new_cache_size), :],
                         v[:, :, :-(new_cache_size), :])
                     for k, v in past_key_values
                 ]
             else:
-                # for chatglm, cache shape is [sl, bs, nh, hn]
+                # chatglm2 & chatglm3, cache shape is [sl, bs, nh, hn]
                 past_key_values = [
                     (k[:-(new_cache_size), :, :, :],
                         v[:-(new_cache_size), :, :, :])
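The two branches crop different dimensions because the KV-cache layouts differ: GLM4 checkpoints keep the sequence dimension at index 2 (a [bs, nh, sl, hn]-style layout, inferred here from the dim-2 slice above), while chatglm2/chatglm3 keep it at index 0 ([sl, bs, nh, hn], as the comment states). A small sketch with made-up sizes:

    import torch

    bs, nh, sl, hn, crop = 1, 2, 10, 4, 3      # toy sizes, not from the commit

    k_glm4 = torch.zeros(bs, nh, sl, hn)       # glm4-style cache: crop the sequence dim (dim 2)
    print(k_glm4[:, :, :-crop, :].shape)       # torch.Size([1, 2, 7, 4])

    k_chatglm3 = torch.zeros(sl, bs, nh, hn)   # chatglm2/3-style cache: crop dim 0
    print(k_chatglm3[:-crop, :, :, :].shape)   # torch.Size([7, 1, 2, 4])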
				
			
@@ -631,6 +652,127 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
     return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
 
 
+def _prepare_generate_args_4_45(self, inputs, generation_config, streamer=None, **kwargs):
+    # 1. Handle `generation_config` and kwargs that might update it,
+    # and validate the `.generate()` call
+    self._validate_model_class()
+    # Pull this out first, we only use it for stopping criteria
+    tokenizer = kwargs.pop("tokenizer", None)
+    generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+    self._validate_model_kwargs(model_kwargs.copy())
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = kwargs.pop("logits_processor", None)
+    stopping_criteria = kwargs.pop("stopping_criteria", None)
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = \
+        stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+    accepts_attention_mask = \
+        "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+    requires_attention_mask = "encoder_outputs" not in model_kwargs
+    kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+    # 3. Define model inputs
+    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+        inputs, generation_config.bos_token_id, model_kwargs
+    )
+    batch_size = inputs_tensor.shape[0]
+
+    device = inputs_tensor.device
+    self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
+    # decoder-only models must use left-padding for batched generation.
+    from transformers.utils import is_torchdynamo_compiling
+    if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
+        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+        # Note: If using, `inputs_embeds` this check does not work
+        # because we want to be more hands-off.
+        if (
+            generation_config._pad_token_tensor is not None
+            and batch_size > 1
+            and len(inputs_tensor.shape) == 2
+            and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
+        ):
+            logger.warning(
+                "A decoder-only architecture is being used, but right-padding was detected! "
+                "For correct generation results, please set `padding_side='left'` "
+                "when initializing the tokenizer."
+            )
+    else:
+        perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
+        if perf_mode == "1":
+            error_str = "IPEX-LLM performance mode"
+        else:
+            error_str = "IPEX-LLM lookup or speculative generation"
+        invalidInputError(False, f"Encoder-decoder models are not supported now for {error_str}.")
+
+    # 4. Define other model kwargs
+    # decoder-only models with inputs_embeds forwarding must use caching
+    # (otherwise we can't detect whether we are generating the first new token or not,
+    # and we only want to use the embeddings for the first new token)
+    if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+        model_kwargs["use_cache"] = True
+    else:
+        model_kwargs["use_cache"] = generation_config.use_cache
+
+    if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
+        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+            inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+        )
+    elif kwargs_has_attention_mask:
+        if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
+            invalidInputError(False, "`attention_mask` passed to `generate` must be 2D.")
+
+    # 5. Prepare `input_ids` which will be used for auto-regressive generation
+    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+    if generation_config.token_healing:
+        input_ids = self.heal_tokens(input_ids, tokenizer)
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
+
+    # 6. Prepare `max_length` depending on other stopping criteria.
+    input_ids_length = input_ids.shape[-1]
+    # skip due to individul max_new_token dealing
+
+    # 7. Prepare the cache.
+    # skip
+
+    # 8. determine generation mode
+    # skip
+    if streamer is not None and (generation_config.num_beams > 1):
+        invalidInputError(
+            False,
+            "`streamer` cannot be used with beam search (yet!). "
+            "Make sure that `num_beams` is set to 1."
+        )
+
+    # 9. prepare logits processors and stopping criteria
+    prefix_allowed_tokens_fn = kwargs.pop("prefix_allowed_tokens_fn", None)
+    prepared_logits_processor = self._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_length,
+        encoder_input_ids=inputs_tensor,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+        device=inputs_tensor.device,
+        model_kwargs=model_kwargs,
+        negative_prompt_ids=None,
+        negative_prompt_attention_mask=None,
+    )
+    prepared_stopping_criteria = self._get_stopping_criteria(
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria,
+        tokenizer=tokenizer,
+        **kwargs
+    )
+
+    return input_ids, generation_config, prepared_logits_processor, prepared_stopping_criteria,\
+        model_kwargs
+
+
 def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
                          return_dict=True, use_cache=True):
     forward_args = {
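For orientation, a bare-bones greedy loop (illustrative only, not the actual ipex-llm lookup/speculative loop) showing how the objects returned by a prepare-args helper are typically consumed; it assumes batch size 1, no KV cache, transformers ~4.45 stopping-criteria semantics, and a causal LM whose forward returns `.logits`:

    import torch

    def toy_greedy_decode(model, input_ids, logits_processor, stopping_criteria, max_new_tokens=32):
        for _ in range(max_new_tokens):
            logits = model(input_ids).logits[:, -1, :]            # last-position logits
            logits = logits_processor(input_ids, logits)          # e.g. repetition penalty
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if stopping_criteria(input_ids, None).all():          # e.g. EOS / max length reached
                break
        return input_ids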
				
			
@@ -643,7 +785,13 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
         forward_args["attention_mask"] = cur_attention_mask
 
     if self.config.model_type == "chatglm":
-        past_key_value_len = past_key_values[0][0].shape[0]
+        if isinstance(self.config.eos_token_id, list) and not hasattr(self.transformer, "vision") \
+                and self.config.num_layers in [28, 40]:
+            # glm4 models
+            past_key_value_len = past_key_values[0][0].shape[2]
+        else:
+            # chatglm2 and chatglm3
+            past_key_value_len = past_key_values[0][0].shape[0]
         position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
                                     device=verify_input_ids.device)
         position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
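The GLM4 branch reads the past length from dim 2 of the cache (dim 0 for chatglm2/3) and offsets the verification positions by it. A tiny sketch with made-up numbers:

    import torch

    past_key_value_len = 8                      # tokens already in the cache (toy value)
    verify_len = 3                              # draft tokens being verified (toy value)
    position_ids = torch.arange(verify_len, dtype=torch.long) + past_key_value_len
    position_ids = position_ids.unsqueeze(0)    # add the batch dim -> shape [1, 3]
    print(position_ids)                         # tensor([[ 8,  9, 10]])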