Support performance mode of GLM4 model (#12401)

* Initial support of prepare generation args for transformers 4.45

* Small fix to chatglm4 model optimization

* Small fix

* fix glm4 position id

* fix glm4 error

* Small change in condition & fix based on comments

* Style fixes

---------

Co-authored-by: cyita <yitastudy@gmail.com>
Yuwen Hu 2024-11-18 18:46:52 +08:00 committed by GitHub
parent d2c821d458
commit a69395f31f
3 changed files with 177 additions and 16 deletions


@@ -27,9 +27,11 @@ import time
 import copy
 import random
 import logging
+import transformers
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
 from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
-    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
+    _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
+    _prepare_generate_args_4_45
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.utils import get_xpu_device_type
@@ -278,16 +280,21 @@ def lookup_generate(self,
                     streamer: Optional["BaseStreamer"] = None,
                     attention_mask=None,
                     **sampling_kwargs):
-    input_ids, generation_config, logits_processor, stopping_criteria, \
-        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
-                                              **sampling_kwargs)
+    from packaging import version
+    trans_version = transformers.__version__
+    if version.parse(trans_version) >= version.parse("4.45.0"):
+        input_ids, generation_config, logits_processor, stopping_criteria, \
+            model_kwargs = _prepare_generate_args_4_45(self, inputs, generation_config,
+                                                       streamer, **sampling_kwargs)
+    else:
+        input_ids, generation_config, logits_processor, stopping_criteria, \
+            model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+                                                  streamer, **sampling_kwargs)
 
     invalidInputError(input_ids.shape[0] == 1,
                       "Prompt lookup is currently not supported with batch inference.")
-    if streamer is not None:
-        streamer.put(input_ids.cpu())
 
     device_name = get_xpu_device_type(input_ids)
 
     candidates_generator = PromptLookupCandidateGenerator(
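
The dispatch above follows a common pattern: check the installed transformers version once and route to the matching argument-preparation helper. Below is a minimal sketch of that pattern, assuming ipex-llm is installed; prepare_args is a hypothetical wrapper name, while the two underscore-prefixed helpers are the ones imported in the diff.

# Sketch: route to the transformers>=4.45 preparation path at runtime.
import transformers
from packaging import version
from ipex_llm.transformers.speculative import (
    _prepare_generate_args,
    _prepare_generate_args_4_45,
)

def prepare_args(model, inputs, generation_config=None, streamer=None, **kwargs):
    # transformers 4.45 reworked generation internals, so a dedicated path is used.
    if version.parse(transformers.__version__) >= version.parse("4.45.0"):
        return _prepare_generate_args_4_45(model, inputs, generation_config,
                                           streamer, **kwargs)
    return _prepare_generate_args(model, inputs, generation_config,
                                  streamer, **kwargs)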


@@ -76,9 +76,15 @@ def chatglm4_model_forward(
     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or\
                 (past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(input_ids,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
+            if self.config.hidden_size == 4096:
+                # glm4-9b
+                full_attention_mask = self.get_masks(input_ids,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
+            else:
+                full_attention_mask = self.get_masks(inputs_embeds,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
 
     # ipex-llm changes begin
     # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`


@@ -14,16 +14,34 @@
 # limitations under the License.
 #
 # Some parts of this file is adapted from
-# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
-# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
-# /utils.py
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py and
+# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/utils.py
+# which are licensed under Apache License 2.0:
 #
+# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors
+# and The HuggingFace Inc. team.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# and https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
 
 import torch
 import time
 import os
 import copy
 import logging
+import inspect
 import transformers
 from packaging import version
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -493,14 +511,17 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
                 for k, v in past_key_values
             ]
     elif self.config.model_type == "chatglm":
-        if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
+        if isinstance(self.config.eos_token_id, list) and \
+                not hasattr(self.transformer, "vision") and \
+                self.config.num_layers in [28, 40]:
+            # glm4 models
             past_key_values = [
                 (k[:, :, :-(new_cache_size), :],
                  v[:, :, :-(new_cache_size), :])
                 for k, v in past_key_values
             ]
         else:
-            # for chatglm, cache shape is [sl, bs, nh, hn]
+            # chatglm2 & chatglm3, cache shape is [sl, bs, nh, hn]
             past_key_values = [
                 (k[:-(new_cache_size), :, :, :],
                  v[:-(new_cache_size), :, :, :])
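
For reference, the two branches above differ only in which dimension holds the sequence length: the diff's comment gives [sl, bs, nh, hn] for chatglm2/chatglm3, while the glm4 branch slices dimension 2, i.e. a [bs, nh, sl, hn] layout. A minimal sketch with dummy tensors (the head count and head dim below are placeholders, not taken from the models):

# Sketch: cropping trailing tokens from the two KV-cache layouts.
import torch

new_cache_size = 4  # trailing tokens (e.g. rejected draft tokens) to drop

glm4_key = torch.zeros(1, 32, 16, 128)      # [bs, nh, sl, hn]: seq len on dim 2
chatglm3_key = torch.zeros(16, 1, 32, 128)  # [sl, bs, nh, hn]: seq len on dim 0

print(glm4_key[:, :, :-new_cache_size, :].shape)      # torch.Size([1, 32, 12, 128])
print(chatglm3_key[:-new_cache_size, :, :, :].shape)  # torch.Size([12, 1, 32, 128])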
@@ -631,6 +652,127 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
     return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
 
 
+def _prepare_generate_args_4_45(self, inputs, generation_config, streamer=None, **kwargs):
+    # 1. Handle `generation_config` and kwargs that might update it,
+    # and validate the `.generate()` call
+    self._validate_model_class()
+
+    # Pull this out first, we only use it for stopping criteria
+    tokenizer = kwargs.pop("tokenizer", None)
+    generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+    self._validate_model_kwargs(model_kwargs.copy())
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = kwargs.pop("logits_processor", None)
+    stopping_criteria = kwargs.pop("stopping_criteria", None)
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = \
+        stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+    accepts_attention_mask = \
+        "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+    requires_attention_mask = "encoder_outputs" not in model_kwargs
+    kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+    # 3. Define model inputs
+    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+        inputs, generation_config.bos_token_id, model_kwargs
+    )
+    batch_size = inputs_tensor.shape[0]
+
+    device = inputs_tensor.device
+    self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
+    # decoder-only models must use left-padding for batched generation.
+    from transformers.utils import is_torchdynamo_compiling
+    if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
+        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+        # Note: If using, `inputs_embeds` this check does not work
+        # because we want to be more hands-off.
+        if (
+            generation_config._pad_token_tensor is not None
+            and batch_size > 1
+            and len(inputs_tensor.shape) == 2
+            and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
+        ):
+            logger.warning(
+                "A decoder-only architecture is being used, but right-padding was detected! "
+                "For correct generation results, please set `padding_side='left'` "
+                "when initializing the tokenizer."
+            )
+    else:
+        perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
+        if perf_mode == "1":
+            error_str = "IPEX-LLM performance mode"
+        else:
+            error_str = "IPEX-LLM lookup or speculative generation"
+        invalidInputError(False, f"Encoder-decoder models are not supported now for {error_str}.")
+
+    # 4. Define other model kwargs
+    # decoder-only models with inputs_embeds forwarding must use caching
+    # (otherwise we can't detect whether we are generating the first new token or not,
+    # and we only want to use the embeddings for the first new token)
+    if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+        model_kwargs["use_cache"] = True
+    else:
+        model_kwargs["use_cache"] = generation_config.use_cache
+
+    if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
+        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+            inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+        )
+    elif kwargs_has_attention_mask:
+        if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
+            invalidInputError(False, "`attention_mask` passed to `generate` must be 2D.")
+
+    # 5. Prepare `input_ids` which will be used for auto-regressive generation
+    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+    if generation_config.token_healing:
+        input_ids = self.heal_tokens(input_ids, tokenizer)
+
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
+
+    # 6. Prepare `max_length` depending on other stopping criteria.
+    input_ids_length = input_ids.shape[-1]
+    # skip due to individual max_new_token handling
+
+    # 7. Prepare the cache.
+    # skip
+
+    # 8. determine generation mode
+    # skip
+
+    if streamer is not None and (generation_config.num_beams > 1):
+        invalidInputError(
+            False,
+            "`streamer` cannot be used with beam search (yet!). "
+            "Make sure that `num_beams` is set to 1."
+        )
+
+    # 9. prepare logits processors and stopping criteria
+    prefix_allowed_tokens_fn = kwargs.pop("prefix_allowed_tokens_fn", None)
+    prepared_logits_processor = self._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_length,
+        encoder_input_ids=inputs_tensor,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+        device=inputs_tensor.device,
+        model_kwargs=model_kwargs,
+        negative_prompt_ids=None,
+        negative_prompt_attention_mask=None,
+    )
+    prepared_stopping_criteria = self._get_stopping_criteria(
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria,
+        tokenizer=tokenizer,
+        **kwargs
+    )
+
+    return input_ids, generation_config, prepared_logits_processor, prepared_stopping_criteria,\
+        model_kwargs
 
 
 def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
                          return_dict=True, use_cache=True):
     forward_args = {
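
As a side note on step 3 of _prepare_generate_args_4_45 above, the right-padding warning boils down to checking whether any sequence in a batch of 2D input ids ends with the pad token. A small self-contained illustration with dummy ids (pad_token_id = 0 is an assumption for the example only):

# Sketch: detect right-padding in a batch of decoder-only inputs.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7],
                          [8, 9, 0]])  # second row ends with the pad id

right_padded = (
    input_ids.shape[0] > 1
    and input_ids.dim() == 2
    and bool((input_ids[:, -1] == pad_token_id).sum() > 0)
)
print(right_padded)  # True -> the generate path warns to set padding_side='left'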
@@ -643,6 +785,12 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
         forward_args["attention_mask"] = cur_attention_mask
 
     if self.config.model_type == "chatglm":
-        past_key_value_len = past_key_values[0][0].shape[0]
+        if isinstance(self.config.eos_token_id, list) and not hasattr(self.transformer, "vision") \
+                and self.config.num_layers in [28, 40]:
+            # glm4 models
+            past_key_value_len = past_key_values[0][0].shape[2]
+        else:
+            # chatglm2 and chatglm3
+            past_key_value_len = past_key_values[0][0].shape[0]
         position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
                                     device=verify_input_ids.device)
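
The glm4 branch above reads the cached length from dimension 2 rather than dimension 0, matching the cache layout difference handled in _crop_past_key_values. A toy sketch of how such a length is typically used to continue position ids past the cached prefix; the offset step is an assumption for illustration and lies outside the lines shown in this hunk:

# Sketch: position ids for draft tokens that continue a cached sequence.
import torch

verify_input_ids = torch.tensor([[101, 102, 103]])  # dummy draft tokens to verify
past_key_value_len = 7                              # tokens already held in the KV cache

position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long)
position_ids = position_ids + past_key_value_len
print(position_ids)  # tensor([7, 8, 9])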