Support performance mode of GLM4 model (#12401)
* Initial support of prepare generation args for transformers 445 * Small fix to chatglm4 model optimization * Small fix * fix glm4 position id * fix glm4 error * Small change in conditon & fix based on comments * Style fixes --------- Co-authored-by: cyita <yitastudy@gmail.com>
This commit is contained in:
parent
d2c821d458
commit
a69395f31f
3 changed files with 177 additions and 16 deletions
|
|
@ -27,9 +27,11 @@ import time
|
||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
|
import transformers
|
||||||
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
||||||
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
|
from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to_probs,\
|
||||||
_crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks
|
_crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\
|
||||||
|
_prepare_generate_args_4_45
|
||||||
from ipex_llm.utils.common import invalidInputError
|
from ipex_llm.utils.common import invalidInputError
|
||||||
from ipex_llm.transformers.utils import get_xpu_device_type
|
from ipex_llm.transformers.utils import get_xpu_device_type
|
||||||
|
|
||||||
|
|
@ -278,16 +280,21 @@ def lookup_generate(self,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
**sampling_kwargs):
|
**sampling_kwargs):
|
||||||
|
from packaging import version
|
||||||
|
trans_version = transformers.__version__
|
||||||
|
|
||||||
|
if version.parse(trans_version) >= version.parse("4.45.0"):
|
||||||
|
input_ids, generation_config, logits_processor, stopping_criteria, \
|
||||||
|
model_kwargs = _prepare_generate_args_4_45(self, inputs, generation_config,
|
||||||
|
streamer, **sampling_kwargs)
|
||||||
|
else:
|
||||||
input_ids, generation_config, logits_processor, stopping_criteria, \
|
input_ids, generation_config, logits_processor, stopping_criteria, \
|
||||||
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
|
model_kwargs = _prepare_generate_args(self, inputs, generation_config,
|
||||||
**sampling_kwargs)
|
streamer, **sampling_kwargs)
|
||||||
|
|
||||||
invalidInputError(input_ids.shape[0] == 1,
|
invalidInputError(input_ids.shape[0] == 1,
|
||||||
"Prompt lookup is currently not supported with batch inference.")
|
"Prompt lookup is currently not supported with batch inference.")
|
||||||
|
|
||||||
if streamer is not None:
|
|
||||||
streamer.put(input_ids.cpu())
|
|
||||||
|
|
||||||
device_name = get_xpu_device_type(input_ids)
|
device_name = get_xpu_device_type(input_ids)
|
||||||
|
|
||||||
candidates_generator = PromptLookupCandidateGenerator(
|
candidates_generator = PromptLookupCandidateGenerator(
|
||||||
|
|
|
||||||
|
|
@ -76,9 +76,15 @@ def chatglm4_model_forward(
|
||||||
if full_attention_mask is None:
|
if full_attention_mask is None:
|
||||||
if (attention_mask is not None and not attention_mask.all()) or\
|
if (attention_mask is not None and not attention_mask.all()) or\
|
||||||
(past_key_values and seq_length != 1):
|
(past_key_values and seq_length != 1):
|
||||||
|
if self.config.hidden_size == 4096:
|
||||||
|
# glm4-9b
|
||||||
full_attention_mask = self.get_masks(input_ids,
|
full_attention_mask = self.get_masks(input_ids,
|
||||||
past_key_values,
|
past_key_values,
|
||||||
padding_mask=attention_mask)
|
padding_mask=attention_mask)
|
||||||
|
else:
|
||||||
|
full_attention_mask = self.get_masks(inputs_embeds,
|
||||||
|
past_key_values,
|
||||||
|
padding_mask=attention_mask)
|
||||||
|
|
||||||
# ipex-llm changes begin
|
# ipex-llm changes begin
|
||||||
# 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
|
# 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,34 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
# Some parts of this file is adapted from
|
# Some parts of this file is adapted from
|
||||||
# https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py and
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py and
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation
|
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/utils.py
|
||||||
# /utils.py
|
# which are licensed under Apache License 2.0:
|
||||||
#
|
#
|
||||||
|
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors
|
||||||
|
# and The HuggingFace Inc. team.
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# and https://github.com/dilab-zju/self-speculative-decoding/blob/main/decoding.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import inspect
|
||||||
import transformers
|
import transformers
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
@ -493,14 +511,17 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
|
||||||
for k, v in past_key_values
|
for k, v in past_key_values
|
||||||
]
|
]
|
||||||
elif self.config.model_type == "chatglm":
|
elif self.config.model_type == "chatglm":
|
||||||
if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
|
if isinstance(self.config.eos_token_id, list) and \
|
||||||
|
not hasattr(self.transformer, "vision") and \
|
||||||
|
self.config.num_layers in [28, 40]:
|
||||||
|
# glm4 models
|
||||||
past_key_values = [
|
past_key_values = [
|
||||||
(k[:, :, :-(new_cache_size), :],
|
(k[:, :, :-(new_cache_size), :],
|
||||||
v[:, :, :-(new_cache_size), :])
|
v[:, :, :-(new_cache_size), :])
|
||||||
for k, v in past_key_values
|
for k, v in past_key_values
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# for chatglm, cache shape is [sl, bs, nh, hn]
|
# chatglm2 & chatglm3, cache shape is [sl, bs, nh, hn]
|
||||||
past_key_values = [
|
past_key_values = [
|
||||||
(k[:-(new_cache_size), :, :, :],
|
(k[:-(new_cache_size), :, :, :],
|
||||||
v[:-(new_cache_size), :, :, :])
|
v[:-(new_cache_size), :, :, :])
|
||||||
|
|
@ -631,6 +652,127 @@ def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sam
|
||||||
return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
|
return input_ids, generation_config, logits_processor, stopping_criteria, model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_generate_args_4_45(self, inputs, generation_config, streamer=None, **kwargs):
|
||||||
|
# 1. Handle `generation_config` and kwargs that might update it,
|
||||||
|
# and validate the `.generate()` call
|
||||||
|
self._validate_model_class()
|
||||||
|
# Pull this out first, we only use it for stopping criteria
|
||||||
|
tokenizer = kwargs.pop("tokenizer", None)
|
||||||
|
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
||||||
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
|
|
||||||
|
# 2. Set generation parameters if not already defined
|
||||||
|
logits_processor = kwargs.pop("logits_processor", None)
|
||||||
|
stopping_criteria = kwargs.pop("stopping_criteria", None)
|
||||||
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
|
stopping_criteria = \
|
||||||
|
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
|
|
||||||
|
accepts_attention_mask = \
|
||||||
|
"attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||||
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||||
|
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||||
|
|
||||||
|
# 3. Define model inputs
|
||||||
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
||||||
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
|
)
|
||||||
|
batch_size = inputs_tensor.shape[0]
|
||||||
|
|
||||||
|
device = inputs_tensor.device
|
||||||
|
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
|
||||||
|
|
||||||
|
# decoder-only models must use left-padding for batched generation.
|
||||||
|
from transformers.utils import is_torchdynamo_compiling
|
||||||
|
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
|
||||||
|
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
||||||
|
# Note: If using, `inputs_embeds` this check does not work
|
||||||
|
# because we want to be more hands-off.
|
||||||
|
if (
|
||||||
|
generation_config._pad_token_tensor is not None
|
||||||
|
and batch_size > 1
|
||||||
|
and len(inputs_tensor.shape) == 2
|
||||||
|
and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"A decoder-only architecture is being used, but right-padding was detected! "
|
||||||
|
"For correct generation results, please set `padding_side='left'` "
|
||||||
|
"when initializing the tokenizer."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
perf_mode = os.environ.get("IPEX_LLM_PERFORMANCE_MODE", None)
|
||||||
|
if perf_mode == "1":
|
||||||
|
error_str = "IPEX-LLM performance mode"
|
||||||
|
else:
|
||||||
|
error_str = "IPEX-LLM lookup or speculative generation"
|
||||||
|
invalidInputError(False, f"Encoder-decoder models are not supported now for {error_str}.")
|
||||||
|
|
||||||
|
# 4. Define other model kwargs
|
||||||
|
# decoder-only models with inputs_embeds forwarding must use caching
|
||||||
|
# (otherwise we can't detect whether we are generating the first new token or not,
|
||||||
|
# and we only want to use the embeddings for the first new token)
|
||||||
|
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
|
||||||
|
model_kwargs["use_cache"] = True
|
||||||
|
else:
|
||||||
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
|
|
||||||
|
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
|
||||||
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||||
|
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
|
||||||
|
)
|
||||||
|
elif kwargs_has_attention_mask:
|
||||||
|
if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
|
||||||
|
invalidInputError(False, "`attention_mask` passed to `generate` must be 2D.")
|
||||||
|
|
||||||
|
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||||
|
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||||
|
|
||||||
|
if generation_config.token_healing:
|
||||||
|
input_ids = self.heal_tokens(input_ids, tokenizer)
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(input_ids.cpu())
|
||||||
|
|
||||||
|
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||||
|
input_ids_length = input_ids.shape[-1]
|
||||||
|
# skip due to individul max_new_token dealing
|
||||||
|
|
||||||
|
# 7. Prepare the cache.
|
||||||
|
# skip
|
||||||
|
|
||||||
|
# 8. determine generation mode
|
||||||
|
# skip
|
||||||
|
if streamer is not None and (generation_config.num_beams > 1):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
"`streamer` cannot be used with beam search (yet!). "
|
||||||
|
"Make sure that `num_beams` is set to 1."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 9. prepare logits processors and stopping criteria
|
||||||
|
prefix_allowed_tokens_fn = kwargs.pop("prefix_allowed_tokens_fn", None)
|
||||||
|
prepared_logits_processor = self._get_logits_processor(
|
||||||
|
generation_config=generation_config,
|
||||||
|
input_ids_seq_length=input_ids_length,
|
||||||
|
encoder_input_ids=inputs_tensor,
|
||||||
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
device=inputs_tensor.device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
negative_prompt_ids=None,
|
||||||
|
negative_prompt_attention_mask=None,
|
||||||
|
)
|
||||||
|
prepared_stopping_criteria = self._get_stopping_criteria(
|
||||||
|
generation_config=generation_config,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return input_ids, generation_config, prepared_logits_processor, prepared_stopping_criteria,\
|
||||||
|
model_kwargs
|
||||||
|
|
||||||
|
|
||||||
def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
|
def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_mask=None,
|
||||||
return_dict=True, use_cache=True):
|
return_dict=True, use_cache=True):
|
||||||
forward_args = {
|
forward_args = {
|
||||||
|
|
@ -643,6 +785,12 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_
|
||||||
forward_args["attention_mask"] = cur_attention_mask
|
forward_args["attention_mask"] = cur_attention_mask
|
||||||
|
|
||||||
if self.config.model_type == "chatglm":
|
if self.config.model_type == "chatglm":
|
||||||
|
if isinstance(self.config.eos_token_id, list) and not hasattr(self.transformer, "vision") \
|
||||||
|
and self.config.num_layers in [28, 40]:
|
||||||
|
# glm4 models
|
||||||
|
past_key_value_len = past_key_values[0][0].shape[2]
|
||||||
|
else:
|
||||||
|
# chatglm2 and chatglm3
|
||||||
past_key_value_len = past_key_values[0][0].shape[0]
|
past_key_value_len = past_key_values[0][0].shape[0]
|
||||||
position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
|
position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long,
|
||||||
device=verify_input_ids.device)
|
device=verify_input_ids.device)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue