From a69395f31f09de985b4030c832cc0f2c81022da3 Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Mon, 18 Nov 2024 18:46:52 +0800
Subject: [PATCH] Support performance mode of GLM4 model (#12401)

* Initial support of prepare generation args for transformers 4.45
* Small fix to chatglm4 model optimization
* Small fix
* fix glm4 position id
* fix glm4 error
* Small change in condition & fix based on comments
* Style fixes

---------

Co-authored-by: cyita
---
 .../llm/src/ipex_llm/transformers/lookup.py   |  21 ++-
 .../ipex_llm/transformers/models/chatglm4.py  |  12 +-
 .../src/ipex_llm/transformers/speculative.py  | 160 +++++++++++++++++-
 3 files changed, 177 insertions(+), 16 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py
index c5fe81d4..062dafb0 100644
--- a/python/llm/src/ipex_llm/transformers/lookup.py
+++ b/python/llm/src/ipex_llm/transformers/lookup.py
@@ -27,9 +27,11 @@
 import time
 import copy
 import random
 import logging
+import transformers
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
-    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
+    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
+    _prepare_generate_args_4_45
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.utils import get_xpu_device_type
@@ -278,16 +280,21 @@ def lookup_generate(self,
                     streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
-    input_ids, generation_config, logits_processor, stopping_criteria, \
-        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
-                                              **sampling_kwargs)
+    from packaging import version
+    trans_version = transformers.__version__
+
+    if version.parse(trans_version) >= version.parse("4.45.0"):
+        input_ids, generation_config, logits_processor, stopping_criteria, \
+            model_kwargs = _prepare_generate_args_4_45(self, inputs, generation_config,
+                                                       streamer, **sampling_kwargs)
+    else:
+        input_ids, generation_config, logits_processor, stopping_criteria, \
+            model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+                                                  streamer, **sampling_kwargs)
 
     invalidInputError(input_ids.shape[0] == 1,
                       "Prompt lookup is currently not supported with batch inference.")
 
-    if streamer is not None:
-        streamer.put(input_ids.cpu())
-
     device_name = get_xpu_device_type(input_ids)
 
     candidates_generator = PromptLookupCandidateGenerator(
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index e3ba6bdf..72ac00a1 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -76,9 +76,15 @@ def chatglm4_model_forward(
     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or\
                 (past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(input_ids,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
+            if self.config.hidden_size == 4096:
+                # glm4-9b
+                full_attention_mask = self.get_masks(input_ids,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
+            else:
+                full_attention_mask = self.get_masks(inputs_embeds,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
 
     # ipex-llm changes begin
     # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 4600e99f..1d3107c8 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -14,16 +14,34 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
-# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
-# /utils.py
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py and
+# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/utils.py
+# which are licensed under Apache License 2.0:
 #
+# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors
+# and The HuggingFace Inc. team.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# and https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
 
 import torch
 import time
 import os
 import copy
 import logging
+import inspect
 import transformers
 from packaging import version
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -493,14 +511,17 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
             for k, v in past_key_values
         ]
     elif self.config.model_type == "chatglm":
-        if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+        if isinstance(self.config.eos_token_id, list) and \
+                not hasattr(self.transformer, "vision") and \
+                self.config.num_layers in [28, 40]:
+            # glm4 models
             past_key_values = [
                 (k[:, :, :-(new_cache_size), :],
                  v[:, :, :-(new_cache_size), :])
                 for k, v in past_key_values
             ]
         else:
-            # for chatglm, cache shape is [sl, bs, nh, hn]
+            # chatglm2 & chatglm3, cache shape is [sl, bs, nh, hn]
             past_key_values = [
                 (k[:-(new_cache_size), :, :, :],
                  v[:-(new_cache_size), :, :, :])
@@ -631,6 +652,127 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
     return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
 
 
+def _prepare_generate_args_4_45(self, inputs, generation_config, streamer=None, **kwargs):
+    # 1. Handle `generation_config` and kwargs that might update it,
+    # and validate the `.generate()` call
+    self._validate_model_class()
+    # Pull this out first, we only use it for stopping criteria
+    tokenizer = kwargs.pop("tokenizer", None)
+    generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+    self._validate_model_kwargs(model_kwargs.copy())
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = kwargs.pop("logits_processor", None)
+    stopping_criteria = kwargs.pop("stopping_criteria", None)
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = \
+        stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+    accepts_attention_mask = \
+        "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+    requires_attention_mask = "encoder_outputs" not in model_kwargs
+    kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+    # 3. Define model inputs
+    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+        inputs, generation_config.bos_token_id, model_kwargs
+    )
+    batch_size = inputs_tensor.shape[0]
+
+    device = inputs_tensor.device
+    self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
+    # decoder-only models must use left-padding for batched generation.
+    from transformers.utils import is_torchdynamo_compiling
+    if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
+        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+        # Note: If using `inputs_embeds`, this check does not work
+        # because we want to be more hands-off.
+        if (
+            generation_config._pad_token_tensor is not None
+            and batch_size > 1
+            and len(inputs_tensor.shape) == 2
+            and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
+        ):
+            logger.warning(
+                "A decoder-only architecture is being used, but right-padding was detected! "
+                "For correct generation results, please set `padding_side='left'` "
+                "when initializing the tokenizer."
+            )
+    else:
+        perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
+        if perf_mode == "1":
+            error_str = "IPEX-LLM performance mode"
+        else:
+            error_str = "IPEX-LLM lookup or speculative generation"
+        invalidInputError(False, f"Encoder-decoder models are not supported now for {error_str}.")
+
+    # 4. Define other model kwargs
+    # decoder-only models with inputs_embeds forwarding must use caching
+    # (otherwise we can't detect whether we are generating the first new token or not,
+    # and we only want to use the embeddings for the first new token)
+    if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+        model_kwargs["use_cache"] = True
+    else:
+        model_kwargs["use_cache"] = generation_config.use_cache
+
+    if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
+        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+            inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+        )
+    elif kwargs_has_attention_mask:
+        if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
+            invalidInputError(False, "`attention_mask` passed to `generate` must be 2D.")
+
+    # 5. Prepare `input_ids` which will be used for auto-regressive generation
+    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+    if generation_config.token_healing:
+        input_ids = self.heal_tokens(input_ids, tokenizer)
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
+
+    # 6. Prepare `max_length` depending on other stopping criteria.
+    input_ids_length = input_ids.shape[-1]
+    # skip due to individual max_new_token handling
+
+    # 7. Prepare the cache.
+    # skip
+
+    # 8. determine generation mode
+    # skip
+    if streamer is not None and (generation_config.num_beams > 1):
+        invalidInputError(
+            False,
+            "`streamer` cannot be used with beam search (yet!). "
+            "Make sure that `num_beams` is set to 1."
+        )
+
+    # 9. prepare logits processors and stopping criteria
+    prefix_allowed_tokens_fn = kwargs.pop("prefix_allowed_tokens_fn", None)
+    prepared_logits_processor = self._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_length,
+        encoder_input_ids=inputs_tensor,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+        device=inputs_tensor.device,
+        model_kwargs=model_kwargs,
+        negative_prompt_ids=None,
+        negative_prompt_attention_mask=None,
+    )
+    prepared_stopping_criteria = self._get_stopping_criteria(
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria,
+        tokenizer=tokenizer,
+        **kwargs
+    )
+
+    return input_ids, generation_config, prepared_logits_processor, prepared_stopping_criteria,\
+        model_kwargs
+
+
 def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
                          return_dict=True, use_cache=True):
     forward_args = {
@@ -643,7 +785,13 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
         forward_args["attention_mask"] = cur_attention_mask
 
     if self.config.model_type == "chatglm":
-        past_key_value_len = past_key_values[0][0].shape[0]
+        if isinstance(self.config.eos_token_id, list) and not hasattr(self.transformer, "vision") \
+                and self.config.num_layers in [28, 40]:
+            # glm4 models
+            past_key_value_len = past_key_values[0][0].shape[2]
+        else:
+            # chatglm2 and chatglm3
+            past_key_value_len = past_key_values[0][0].shape[0]
         position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
                                     device=verify_input_ids.device)
         position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
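
The lookup.py change above gates generation-argument preparation on the installed transformers version: the 4.45-style flow used by `_prepare_generate_args_4_45` (for example `_prepare_special_tokens`, `generation_config._pad_token_tensor`, and optional token healing) differs from the older helper, so the matching routine is picked at runtime. The sketch below shows the same dispatch in isolation; it assumes an ipex-llm build that already contains this patch, and the helper name `pick_prepare_fn` is illustrative only, not part of the patch.

import transformers
from packaging import version

# Both helpers live in the patched module; the 4.45 variant only exists after this patch.
from ipex_llm.transformers.speculative import (
    _prepare_generate_args,
    _prepare_generate_args_4_45,
)


def pick_prepare_fn():
    # Illustrative helper: choose the preparation routine the same way
    # lookup_generate does, based on the installed transformers version.
    if version.parse(transformers.__version__) >= version.parse("4.45.0"):
        return _prepare_generate_args_4_45
    return _prepare_generate_args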
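
Both `_crop_past_key_values` and `_non_cpu_ipex_verify` now branch on the GLM4 cache layout: for glm4 models the sequence length lives on dimension 2 of each cached key/value tensor (hence `shape[2]` and the dim-2 crop), while chatglm2/chatglm3 keep it on dimension 0 with layout [sl, bs, nh, hn]. A minimal self-contained sketch of that cropping rule on dummy tensors follows; the function name `crop_chatglm_cache` and the `is_glm4` flag are illustrative, not part of the patch.

import torch


def crop_chatglm_cache(past_key_values, new_cache_size, is_glm4):
    # Drop the last `new_cache_size` positions from every layer's KV cache.
    # glm4: sequence length on dim 2; chatglm2/chatglm3: sequence length on dim 0.
    if is_glm4:
        return [(k[:, :, :-new_cache_size, :], v[:, :, :-new_cache_size, :])
                for k, v in past_key_values]
    return [(k[:-new_cache_size, :, :, :], v[:-new_cache_size, :, :, :])
            for k, v in past_key_values]


# toy shapes: glm4-style cache with seq len 8 on dim 2; chatglm2/3-style with seq len 8 on dim 0
glm4_cache = [(torch.zeros(1, 2, 8, 4), torch.zeros(1, 2, 8, 4))]
old_cache = [(torch.zeros(8, 1, 2, 4), torch.zeros(8, 1, 2, 4))]
assert crop_chatglm_cache(glm4_cache, 3, True)[0][0].shape == (1, 2, 5, 4)
assert crop_chatglm_cache(old_cache, 3, False)[0][0].shape == (5, 1, 2, 4)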
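
Finally, a hypothetical end-to-end sketch of exercising the path this patch enables. Only the `IPEX_LLM_PERFORMANCE_MODE` environment variable comes from the patched code itself; the checkpoint name and the loading calls follow common ipex-llm GLM4 examples and are assumptions here, so the exact flags may differ in your setup.

import os
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Switch checked by the patched generation code; set it before calling generate().
os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"

model_path = "THUDM/glm-4-9b-chat"  # example GLM4 checkpoint (assumption)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True)
model = model.half().to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))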