Support finishing PP inference once eos_token_id is found (#11336)

binbin Deng 2024-06-18 09:55:40 +08:00 committed by GitHub
parent de4bb97b4f
commit e50c890e1f
4 changed files with 65 additions and 2 deletions


@@ -11,6 +11,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements
 - [meta-llama/Meta-Llama-3-8B-Instruct](./run_llama_arc_2_card.sh)
 - [Qwen/Qwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
 - [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
+- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
 - [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
 - [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
 - [microsoft/Phi-3-mini-4k-instruct](./run_phi3_arc_2_card.sh)
@@ -57,7 +58,7 @@ bash run_llama_arc_2_card.sh
 <details>
 <summary> Show Qwen1.5 example </summary>
-#### Run Qwen1.5-7B-Chat / Qwen1.5-14B-Chat on two Intel Arc A770
+#### Run Qwen1.5-7B-Chat / Qwen1.5-14B-Chat / Qwen1.5-32B-Chat on two Intel Arc A770
 You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for Qwen1.5 to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.


@@ -46,6 +46,7 @@ if __name__ == '__main__':
                                              optimize_model=True,
                                              trust_remote_code=True,
                                              use_cache=True,
+                                             torch_dtype=torch.float16,
                                              pipeline_parallel_stages=args.gpu_num)
     # Load tokenizer
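For reference, a minimal sketch (not part of the commit) of a from_pretrained call that mirrors the arguments shown in this hunk. It assumes the example's `AutoModelForCausalLM` from `ipex_llm.transformers`; the model path and GPU count are placeholders.

import torch
from ipex_llm.transformers import AutoModelForCausalLM

MODEL_PATH = "Qwen/Qwen1.5-32B-Chat"   # placeholder: any supported repo id or local checkpoint path
NUM_GPUS = 2                           # placeholder: number of Intel GPUs / pipeline stages

# Load the model in half precision (torch_dtype is the argument added in this commit)
# and split it into NUM_GPUS pipeline-parallel stages.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True,
                                             torch_dtype=torch.float16,
                                             pipeline_parallel_stages=NUM_GPUS)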


@@ -34,3 +34,7 @@ CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
 # # To run Qwen1.5-14B-Chat
 # CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
 # generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-14B-Chat' --gpu-num $NUM_GPUS
+
+# # To run Qwen1.5-32B-Chat
+# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
+# generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-32B-Chat' --gpu-num $NUM_GPUS


@@ -25,6 +25,9 @@ import time
 import numpy as np
 from typing import Callable, List, Optional
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from ipex_llm.utils.common import invalidInputError
+import logging
+logger = logging.getLogger(__name__)
 
 # patch GenerationMixin.generate
 from transformers import GenerationMixin
@@ -118,12 +121,34 @@ def generate(
     **kwargs,
 ):
     if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
+        # priority: `generation_config` argument > `model.generation_config`
+        if generation_config is None:
+            if (
+                self.generation_config._from_model_config
+                and self.generation_config._original_object_hash == hash(self.generation_config)
+                and self.config._has_non_default_generation_parameters()
+            ):
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning("Setting `pad_token_id` to `eos_token_id`: "
+                           f"{eos_token_id} for open-end generation.")
+            generation_config.pad_token_id = eos_token_id
+
         if generation_config is not None and generation_config.max_new_tokens is not None:
             max_new_tokens = generation_config.max_new_tokens
         else:
             max_new_tokens = kwargs.get("max_new_tokens", None)
         return self.pipeline_parallel_generate(inputs=inputs,
-                                               max_new_tokens=max_new_tokens,)
+                                               max_new_tokens=max_new_tokens,
+                                               generation_config=generation_config,)
 
     return original_generate(self,
                              inputs=inputs,
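A hedged usage sketch (not part of the commit) of what this hunk enables on the caller side: the `GenerationConfig`, including `eos_token_id` and the `pad_token_id` fallback, is now forwarded to `pipeline_parallel_generate`, so generation can stop before `max_new_tokens`. It assumes `model`, `tokenizer`, and `input_ids` already exist as in the surrounding example.

from transformers import GenerationConfig

# pad_token_id is deliberately left unset; the patched generate() warns and
# falls back to eos_token_id for open-ended generation.
gen_config = GenerationConfig(max_new_tokens=64,
                              eos_token_id=tokenizer.eos_token_id)
output_ids = model.generate(input_ids, generation_config=gen_config)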
@@ -143,6 +168,7 @@ GenerationMixin.generate = generate
 def pipeline_parallel_generate(self,
                                inputs: Optional[torch.Tensor] = None,
                                max_new_tokens: int = 32,
+                               generation_config: Optional[GenerationConfig] = None,
                                **kwargs):
     local_rank = dist.get_rank()
     pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
@@ -154,12 +180,22 @@ def pipeline_parallel_generate(self,
     self.first_token_time = 0
     self.next_token_time = []
 
+    pad_token_id = generation_config.pad_token_id
+    eos_token_id = generation_config.eos_token_id
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device) \
+        if eos_token_id is not None else None
+
     _input_ids = None
     _past_key_values = None
     bs = inputs.shape[0]
     output_ids = inputs.clone()
 
     step = 0
+    # keep track of which sequences are already finished
+    unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device)
+    this_peer_finished = False
     while True:
         if step >= max_new_tokens:
             break
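A small illustration (not part of the commit) of the bookkeeping this hunk sets up, run on a toy CPU batch: a scalar eos id is normalized to a tensor, and every sequence starts out unfinished (mask value 1).

import torch

eos_token_id = 2                                      # toy eos id
if isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]
inputs = torch.tensor([[5, 6, 7], [8, 9, 10]])        # toy batch of 2 prompts
eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device)
unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device)
print(eos_token_id_tensor)    # tensor([2])
print(unfinished_sequences)   # tensor([1, 1])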
@@ -190,6 +226,14 @@ def pipeline_parallel_generate(self,
         _input_ids = next_ids
         output_ids = torch.cat([output_ids, next_ids], dim=-1)
 
+        # finished sentences should have their next token be a padding token
+        next_ids = next_ids.squeeze()
+        if eos_token_id is not None:
+            if pad_token_id is None:
+                invalidInputError(False, "If `eos_token_id` is defined, "
+                                  "make sure that `pad_token_id` is defined.")
+            next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
         if isinstance(outputs.past_key_values, tuple) and local_rank != 0:
             value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
             past_key_values_placeholder = tuple(
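A toy example (not part of the commit) of the padding expression above: a finished sequence (mask 0) gets pad_token_id instead of the token it just produced.

import torch

pad_token_id = 0
unfinished_sequences = torch.tensor([1, 0])   # sequence 0 still running, sequence 1 finished
next_ids = torch.tensor([42, 7])              # this step's tokens, already squeezed
next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
print(next_ids)                               # tensor([42, 0])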
@@ -204,6 +248,19 @@ def pipeline_parallel_generate(self,
             self.first_token_time = toc - tic
         else:
             self.next_token_time.append(toc - tic)
+
+        # if eos_token was found in one sentence, set sentence to finished
+        if eos_token_id_tensor is not None:
+            unfinished_sequences = unfinished_sequences.mul(
+                next_ids.tile(eos_token_id_tensor.shape[0], 1)
+                .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+            )
+            # stop when each sentence is finished
+            if unfinished_sequences.max() == 0:
+                this_peer_finished = True
+        if this_peer_finished:
+            break
+
         step += 1
         if self.device.type == 'xpu':
             torch.xpu.synchronize()
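A toy example (not part of the commit) of the tile/ne/prod check above: the step's tokens are tiled once per eos id and compared, and a sequence stays unfinished (1) only if its token matched none of the eos ids. Once the mask reaches all zeros, the generation loop breaks instead of running to max_new_tokens.

import torch

eos_token_id_tensor = torch.tensor([2, 7])    # two possible eos ids
next_ids = torch.tensor([5, 7, 2])            # this step's token for a batch of 3
unfinished_sequences = torch.ones(3, dtype=torch.long)
unfinished_sequences = unfinished_sequences.mul(
    next_ids.tile(eos_token_id_tensor.shape[0], 1)
    .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
print(unfinished_sequences)                   # tensor([1, 0, 0])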