Support performance mode of GLM4 model (#12401)
* Initial support of prepare generation args for transformers 445 * Small fix to chatglm4 model optimization * Small fix * fix glm4 position id * fix glm4 error * Small change in conditon & fix based on comments * Style fixes --------- Co-authored-by: cyita <yitastudy@gmail.com>
This commit is contained in:
parent
d2c821d458
commit
a69395f31f
3 changed files with 177 additions and 16 deletions
|
|
@ -27,9 +27,11 @@ import time
|
|||
import copy
|
||||
import random
|
||||
import logging
|
||||
import transformers
|
||||
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
||||
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
|
||||
_crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
|
||||
_crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
|
||||
_prepare_generate_args_4_45
|
||||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.utils import get_xpu_device_type
|
||||
|
||||
|
|
@ -278,16 +280,21 @@ def lookup_generate(self,
|
|||
streamer: Optional["BaseStreamer"] = None,
|
||||
attention_mask=None,
|
||||
**sampling_kwargs):
|
||||
from packaging import version
|
||||
trans_version = transformers.__version__
|
||||
|
||||
if version.parse(trans_version) >= version.parse("4.45.0"):
|
||||
input_ids, generation_config, logits_processor, stopping_criteria, \
|
||||
model_kwargs = _prepare_generate_args_4_45(self, inputs, generation_config,
|
||||
streamer, **sampling_kwargs)
|
||||
else:
|
||||
input_ids, generation_config, logits_processor, stopping_criteria, \
|
||||
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
|
||||
**sampling_kwargs)
|
||||
streamer, **sampling_kwargs)
|
||||
|
||||
invalidInputError(input_ids.shape[0] == 1,
|
||||
"Prompt lookup is currently not supported with batch inference.")
|
||||
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
device_name = get_xpu_device_type(input_ids)
|
||||
|
||||
candidates_generator = PromptLookupCandidateGenerator(
|
||||
|
|
|
|||
|
|
@ -76,9 +76,15 @@ def chatglm4_model_forward(
|
|||
if full_attention_mask is None:
|
||||
if (attention_mask is not None and not attention_mask.all()) or\
|
||||
(past_key_values and seq_length != 1):
|
||||
if self.config.hidden_size == 4096:
|
||||
# glm4-9b
|
||||
full_attention_mask = self.get_masks(input_ids,
|
||||
past_key_values,
|
||||
padding_mask=attention_mask)
|
||||
else:
|
||||
full_attention_mask = self.get_masks(inputs_embeds,
|
||||
past_key_values,
|
||||
padding_mask=attention_mask)
|
||||
|
||||
# ipex-llm changes begin
|
||||
# 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
|
||||
|
|
|
|||
|
|
@ -14,16 +14,34 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
|
||||
# /utils.py
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py and
|
||||
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/utils.py
|
||||
# which are licensed under Apache License 2.0:
|
||||
#
|
||||
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors
|
||||
# and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# and https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
|
||||
|
||||
import torch
|
||||
import time
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
import inspect
|
||||
import transformers
|
||||
from packaging import version
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
|
@ -493,14 +511,17 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
|
|||
for k, v in past_key_values
|
||||
]
|
||||
elif self.config.model_type == "chatglm":
|
||||
if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
|
||||
if isinstance(self.config.eos_token_id, list) and \
|
||||
not hasattr(self.transformer, "vision") and \
|
||||
self.config.num_layers in [28, 40]:
|
||||
# glm4 models
|
||||
past_key_values = [
|
||||
(k[:, :, :-(new_cache_size), :],
|
||||
v[:, :, :-(new_cache_size), :])
|
||||
for k, v in past_key_values
|
||||
]
|
||||
else:
|
||||
# for chatglm, cache shape is [sl, bs, nh, hn]
|
||||
# chatglm2 & chatglm3, cache shape is [sl, bs, nh, hn]
|
||||
past_key_values = [
|
||||
(k[:-(new_cache_size), :, :, :],
|
||||
v[:-(new_cache_size), :, :, :])
|
||||
|
|
@ -631,6 +652,127 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
|
|||
return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
|
||||
|
||||
|
||||
def _prepare_generate_args_4_45(self, inputs, generation_config, streamer=None, **kwargs):
|
||||
# 1. Handle `generation_config` and kwargs that might update it,
|
||||
# and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
# Pull this out first, we only use it for stopping criteria
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
logits_processor = kwargs.pop("logits_processor", None)
|
||||
stopping_criteria = kwargs.pop("stopping_criteria", None)
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = \
|
||||
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
|
||||
accepts_attention_mask = \
|
||||
"attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||
|
||||
# 3. Define model inputs
|
||||
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
||||
inputs, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
device = inputs_tensor.device
|
||||
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
|
||||
|
||||
# decoder-only models must use left-padding for batched generation.
|
||||
from transformers.utils import is_torchdynamo_compiling
|
||||
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
|
||||
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
||||
# Note: If using, `inputs_embeds` this check does not work
|
||||
# because we want to be more hands-off.
|
||||
if (
|
||||
generation_config._pad_token_tensor is not None
|
||||
and batch_size > 1
|
||||
and len(inputs_tensor.shape) == 2
|
||||
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
|
||||
):
|
||||
logger.warning(
|
||||
"A decoder-only architecture is being used, but right-padding was detected! "
|
||||
"For correct generation results, please set `padding_side='left'` "
|
||||
"when initializing the tokenizer."
|
||||
)
|
||||
else:
|
||||
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
||||
if perf_mode == "1":
|
||||
error_str = "IPEX-LLM performance mode"
|
||||
else:
|
||||
error_str = "IPEX-LLM lookup or speculative generation"
|
||||
invalidInputError(False, f"Encoder-decoder models are not supported now for {error_str}.")
|
||||
|
||||
# 4. Define other model kwargs
|
||||
# decoder-only models with inputs_embeds forwarding must use caching
|
||||
# (otherwise we can't detect whether we are generating the first new token or not,
|
||||
# and we only want to use the embeddings for the first new token)
|
||||
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
|
||||
model_kwargs["use_cache"] = True
|
||||
else:
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
|
||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
|
||||
)
|
||||
elif kwargs_has_attention_mask:
|
||||
if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
|
||||
invalidInputError(False, "`attention_mask` passed to `generate` must be 2D.")
|
||||
|
||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||
|
||||
if generation_config.token_healing:
|
||||
input_ids = self.heal_tokens(input_ids, tokenizer)
|
||||
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
# skip due to individul max_new_token dealing
|
||||
|
||||
# 7. Prepare the cache.
|
||||
# skip
|
||||
|
||||
# 8. determine generation mode
|
||||
# skip
|
||||
if streamer is not None and (generation_config.num_beams > 1):
|
||||
invalidInputError(
|
||||
False,
|
||||
"`streamer` cannot be used with beam search (yet!). "
|
||||
"Make sure that `num_beams` is set to 1."
|
||||
)
|
||||
|
||||
# 9. prepare logits processors and stopping criteria
|
||||
prefix_allowed_tokens_fn = kwargs.pop("prefix_allowed_tokens_fn", None)
|
||||
prepared_logits_processor = self._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
device=inputs_tensor.device,
|
||||
model_kwargs=model_kwargs,
|
||||
negative_prompt_ids=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
)
|
||||
prepared_stopping_criteria = self._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return input_ids, generation_config, prepared_logits_processor, prepared_stopping_criteria,\
|
||||
model_kwargs
|
||||
|
||||
|
||||
def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
|
||||
return_dict=True, use_cache=True):
|
||||
forward_args = {
|
||||
|
|
@ -643,6 +785,12 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
|
|||
forward_args["attention_mask"] = cur_attention_mask
|
||||
|
||||
if self.config.model_type == "chatglm":
|
||||
if isinstance(self.config.eos_token_id, list) and not hasattr(self.transformer, "vision") \
|
||||
and self.config.num_layers in [28, 40]:
|
||||
# glm4 models
|
||||
past_key_value_len = past_key_values[0][0].shape[2]
|
||||
else:
|
||||
# chatglm2 and chatglm3
|
||||
past_key_value_len = past_key_values[0][0].shape[0]
|
||||
position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
|
||||
device=verify_input_ids.device)
|
||||
|
|
|
|||
Loading…
Reference in a new issue