Support performance mode of GLM4 model (#12401)

* Initial support of prepare generation args for transformers 4.45

* Small fix to chatglm4 model optimization

* Small fix

* fix glm4 position id

* fix glm4 error

* Small change in condition & fix based on comments

* Style fixes

---------

Co-authored-by: cyita <yitastudy@gmail.com>
Authored by Yuwen Hu on 2024-11-18 18:46:52 +08:00; committed by GitHub
parent d2c821d458
commit a69395f31f
3 changed files with 177 additions and 16 deletions
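
For context, a minimal usage sketch of the feature this commit targets. Only the IPEX_LLM_PERFORMANCE_MODE environment variable comes from the diff below; the model path, loading flags, and generation parameters are placeholder assumptions, not part of the change.

# Sketch only: enable performance mode before loading, then generate with a GLM4 model.
import os
os.environ["IPEX_LLM_PERFORMANCE_MODE"] = "1"  # env var checked in speculative.py (see diff below)

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "THUDM/glm-4-9b-chat"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                             trust_remote_code=True).to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))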


@@ -27,9 +27,11 @@ import time
import copy
import random
import logging
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
- _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
+ _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
+ _prepare_generate_args_4_45
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.utils import get_xpu_device_type
@@ -278,16 +280,21 @@ def lookup_generate(self,
streamer: Optional["BaseStreamer"] = None,
attention_mask=None,
**sampling_kwargs):
from packaging import version
trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.45.0"):
input_ids, generation_config, logits_processor, stopping_criteria, \
model_kwargs = _prepare_generate_args_4_45(self, inputs, generation_config,
streamer, **sampling_kwargs)
else:
input_ids, generation_config, logits_processor, stopping_criteria, \
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
- **sampling_kwargs)
+ streamer, **sampling_kwargs)
invalidInputError(input_ids.shape[0] == 1,
"Prompt lookup is currently not supported with batch inference.")
if streamer is not None:
streamer.put(input_ids.cpu())
device_name = get_xpu_device_type(input_ids)
candidates_generator = PromptLookupCandidateGenerator(
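
# --- illustrative sketch, not part of the diff ---
# The new branch above dispatches on the installed transformers version. The same
# check written as a standalone helper (the helper name is hypothetical):
import transformers
from packaging import version

def transformers_at_least(minimum: str) -> bool:
    # packaging compares release segments numerically, so "4.45.0" vs "4.9.0" is ordered correctly,
    # unlike a plain string comparison
    return version.parse(transformers.__version__) >= version.parse(minimum)

# transformers_at_least("4.45.0") -> True selects _prepare_generate_args_4_45, else _prepare_generate_args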


@@ -76,9 +76,15 @@ def chatglm4_model_forward(
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or\
(past_key_values and seq_length != 1):
if self.config.hidden_size == 4096:
# glm4-9b
full_attention_mask = self.get_masks(input_ids,
past_key_values,
padding_mask=attention_mask)
else:
full_attention_mask = self.get_masks(inputs_embeds,
past_key_values,
padding_mask=attention_mask)
# ipex-llm changes begin
# 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
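
# --- illustrative sketch, not part of the diff ---
# The branch above chooses which tensor get_masks() receives: hidden_size == 4096
# identifies glm4-9b (per the comment in the change) and keeps the input_ids path,
# while other GLM4 variants build the mask from inputs_embeds. A condensed
# restatement with a hypothetical config stub:
from types import SimpleNamespace

def mask_source(config, input_ids, inputs_embeds):
    # assumption carried over from the diff: hidden_size 4096 <=> glm4-9b
    return input_ids if config.hidden_size == 4096 else inputs_embeds

cfg = SimpleNamespace(hidden_size=4096)
# mask_source(cfg, ids, embeds) -> ids for glm4-9b, embeds for other variants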


@@ -14,16 +14,34 @@
# limitations under the License.
#
# Some parts of this file is adapted from
- # https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
- # https://github.com/huggingface/transformers/blob/main/src/transformers/generation
- # /utils.py
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py and
+ # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/utils.py
# which are licensed under Apache License 2.0:
#
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors
# and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# and https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
import torch
import time
import os
import copy
import logging
import inspect
import transformers
from packaging import version
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -493,14 +511,17 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
for k, v in past_key_values
]
elif self.config.model_type == "chatglm":
- if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+ if isinstance(self.config.eos_token_id, list) and \
+         not hasattr(self.transformer, "vision") and \
+         self.config.num_layers in [28, 40]:
+     # glm4 models
past_key_values = [
(k[:, :, :-(new_cache_size), :],
v[:, :, :-(new_cache_size), :])
for k, v in past_key_values
]
else:
- # for chatglm, cache shape is [sl, bs, nh, hn]
+ # chatglm2 & chatglm3, cache shape is [sl, bs, nh, hn]
past_key_values = [
(k[:-(new_cache_size), :, :, :],
v[:-(new_cache_size), :, :, :])
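
# --- illustrative sketch, not part of the diff ---
# The two chatglm cache layouts handled above, cropping rejected draft tokens off
# the sequence dimension (sizes below are arbitrary examples):
import torch

bs, nh, hn, sl, new_cache_size = 1, 32, 128, 16, 3
k_glm4 = torch.zeros(bs, nh, sl, hn)                 # glm4: sequence length on dim 2
print(k_glm4[:, :, :-new_cache_size, :].shape)       # torch.Size([1, 32, 13, 128])
k_chatglm = torch.zeros(sl, bs, nh, hn)              # chatglm2/3: [sl, bs, nh, hn]
print(k_chatglm[:-new_cache_size, :, :, :].shape)    # torch.Size([13, 1, 32, 128])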
@@ -631,6 +652,127 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
def _prepare_generate_args_4_45(self, inputs, generation_config, streamer=None, **kwargs):
# 1. Handle `generation_config` and kwargs that might update it,
# and validate the `.generate()` call
self._validate_model_class()
# Pull this out first, we only use it for stopping criteria
tokenizer = kwargs.pop("tokenizer", None)
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
logits_processor = kwargs.pop("logits_processor", None)
stopping_criteria = kwargs.pop("stopping_criteria", None)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = \
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = \
"attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
# 3. Define model inputs
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
# decoder-only models must use left-padding for batched generation.
from transformers.utils import is_torchdynamo_compiling
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work
# because we want to be more hands-off.
if (
generation_config._pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! "
"For correct generation results, please set `padding_side='left'` "
"when initializing the tokenizer."
)
else:
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
if perf_mode == "1":
error_str = "IPEX-LLM performance mode"
else:
error_str = "IPEX-LLM lookup or speculative generation"
invalidInputError(False, f"Encoder-decoder models are not supported now for {error_str}.")
# 4. Define other model kwargs
# decoder-only models with inputs_embeds forwarding must use caching
# (otherwise we can't detect whether we are generating the first new token or not,
# and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
else:
model_kwargs["use_cache"] = generation_config.use_cache
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
elif kwargs_has_attention_mask:
if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
invalidInputError(False, "`attention_mask` passed to `generate` must be 2D.")
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
# skip due to individual max_new_token handling
# 7. Prepare the cache.
# skip
# 8. determine generation mode
# skip
if streamer is not None and (generation_config.num_beams > 1):
invalidInputError(
False,
"`streamer` cannot be used with beam search (yet!). "
"Make sure that `num_beams` is set to 1."
)
# 9. prepare logits processors and stopping criteria
prefix_allowed_tokens_fn = kwargs.pop("prefix_allowed_tokens_fn", None)
prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
model_kwargs=model_kwargs,
negative_prompt_ids=None,
negative_prompt_attention_mask=None,
)
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
tokenizer=tokenizer,
**kwargs
)
return input_ids, generation_config, prepared_logits_processor, prepared_stopping_criteria,\
model_kwargs
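
# --- illustrative example, not part of the diff ---
# Why the right-padding warning in step 3 fires: with right padding the last column
# of a batched prompt contains pad tokens, which breaks decoder-only generation.
# The pad id and batch below are made up:
import torch

pad_token_id = 0
batch = torch.tensor([[11, 12, 13],
                      [21, 22, pad_token_id]])      # second prompt is right-padded
print(torch.sum(batch[:, -1] == pad_token_id) > 0)  # tensor(True) -> warning is emitted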
def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
return_dict=True, use_cache=True):
forward_args = {
@@ -643,6 +785,12 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
forward_args["attention_mask"] = cur_attention_mask
if self.config.model_type == "chatglm":
if isinstance(self.config.eos_token_id, list) and not hasattr(self.transformer, "vision") \
and self.config.num_layers in [28, 40]:
# glm4 models
past_key_value_len = past_key_values[0][0].shape[2]
else:
# chatglm2 and chatglm3
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
device=verify_input_ids.device)
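
# --- illustrative sketch, not part of the diff ---
# Where past_key_value_len comes from in the two layouts selected above; the use of
# that length to offset the verify-step position ids is truncated in this hunk, so
# the final comment is an assumption:
import torch

past_glm4 = [(torch.zeros(1, 32, 16, 128), torch.zeros(1, 32, 16, 128))]     # [bs, nh, sl, hn]
past_chatglm = [(torch.zeros(16, 1, 32, 128), torch.zeros(16, 1, 32, 128))]  # [sl, bs, nh, hn]
print(past_glm4[0][0].shape[2])     # 16 -> glm4: sequence length on dim 2
print(past_chatglm[0][0].shape[0])  # 16 -> chatglm2/3: sequence length on dim 0
# the new tokens' position ids would then start at this past length,
# e.g. torch.arange(n_new) + past_key_value_len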