Refactor pipeline parallel multi-stage implementation (#11286)

binbin Deng 2024-06-13 10:00:23 +08:00 committed by GitHub
parent 14b1e6b699
commit 220151e2a1
6 changed files with 271 additions and 124 deletions

README.md

@@ -5,90 +5,48 @@ This example demonstrates how to run IPEX-LLM optimized low-bit model vertically
 ## Requirements
 To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
-> [!NOTE]
-> To run IPEX-LLM on multiple Intel GPUs in pipeline parallel fashion, you will need to install **Intel® oneAPI Base Toolkit 2024.1**, which could be done through an offline installer:
-> ```bash
-> wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/fdc7a2bc-b7a8-47eb-8876-de6201297144/l_BaseKit_p_2024.1.0.596_offline.sh
->
-> sudo sh ./l_BaseKit_p_2024.1.0.596_offline.sh
-> ```
+## Verified Models
+- [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+- [Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
+- [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 ## Example: Run pipeline parallel inference on multiple GPUs
+### 0. Prerequisites
+Please visit the [Install IPEX-LLM on Linux with Intel GPU](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html), follow [Install Intel GPU Driver](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-intel-gpu-driver) and [Install oneAPI](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-oneapi) to install GPU driver and Intel® oneAPI Base Toolkit 2024.0.
 ### 1. Installation
 ```bash
 conda create -n llm python=3.11
 conda activate llm
-# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30+xpu oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 ```
-### 2. Configures OneAPI environment variables
+### 2. Run pipeline parallel inference on multiple GPUs
+For optimal performance, it is recommended to set several environment variables. We provide example usage as following:
+- Run Llama-2-13b-chat-hf on two Intel Arc A770
 ```bash
-source /opt/intel/oneapi/setvars.sh
+bash run_llama2_13b_arc_2_card.sh
 ```
-> [!NOTE]
-> Please make sure you configure the environment variables for **Intel® oneAPI Base Toolkit's version == 2024.1.**.
+> **Note**: You could change `NUM_GPUS` to the number of GPUs you have on your machine.
-### 3 Runtime Configurations
-For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
-<details>
-<summary>For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series</summary>
-```bash
-export USE_XETLA=OFF
-export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
-export SYCL_CACHE_PERSISTENT=1
-```
-</details>
-<details>
-<summary>For Intel Data Center GPU Max Series</summary>
-```bash
-export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
-export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
-export SYCL_CACHE_PERSISTENT=1
-export ENABLE_SDP_FUSION=1
-```
-> [!NOTE]
-> Please note that `libtcmalloc.so` can be installed by `conda install -c conda-forge -y gperftools=2.10`.
-</details>
-### 4. Running examples
-```
-python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT --gpu-num GPU_NUM
-```
-Arguments info:
-- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
-- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
-- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
-- `--gpu-num GPU_NUM`: argument defining the number of GPU to use. It is default to be `2`.
 #### Sample Output
-##### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+##### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
 ```log
 Inference time: xxxx s
+First token cost xxxx s and rest tokens cost average xxxx s
 -------------------- Prompt --------------------
-<s>[INST] <<SYS>>
-<</SYS>>
-What is AI? [/INST]
+Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
 -------------------- Output --------------------
-[INST] <<SYS>>
-<</SYS>>
-What is AI? [/INST] Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
+Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She was always asking her parents to take her on trips, but they were always too busy or too tired.
+One day, the little girl
 ```

generate.py

@@ -19,34 +19,18 @@ import torch
 import time
 import argparse
-from ipex_llm.transformers import AutoModelForCausalLM
+from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
 from transformers import AutoTokenizer
-# you could tune the prompt based on your own model,
-# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
-DEFAULT_SYSTEM_PROMPT = """\
-"""
-def get_prompt(message: str, chat_history: list[tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
+init_pipeline_parallel()
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-13b-chat-hf",
                         help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                              ', or the path to the huggingface checkpoint folder')
-    parser.add_argument('--prompt', type=str, default="What is AI?",
+    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
                         help='Prompt to infer')
     parser.add_argument('--n-predict', type=int, default=32,
                         help='Max tokens to predict')
@@ -66,35 +50,28 @@ if __name__ == '__main__':
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    local_rank = torch.distributed.get_rank()
     # Generate predicted tokens
     with torch.inference_mode():
-        prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0')
+        input_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(f'xpu:{local_rank}')
         # ipex_llm model needs a warmup, then inference time can be accurate
         output = model.generate(input_ids,
-                                do_sample=False,
-                                max_new_tokens=args.n_predict)
-        output = model.generate(input_ids,
-                                do_sample=False,
                                 max_new_tokens=args.n_predict)
         # start inference
         st = time.time()
-        # if your selected model is capable of utilizing previous key/value attentions
-        # to enhance decoding speed, but has `"use_cache": false` in its model config,
-        # it is important to set `use_cache=True` explicitly in the `generate` function
-        # to obtain optimal performance with IPEX-LLM INT4 optimizations
         output = model.generate(input_ids,
-                                do_sample=False,
                                 max_new_tokens=args.n_predict)
         torch.xpu.synchronize()
         end = time.time()
         output = output.cpu()
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        print(f'Inference time: {end-st} s')
-        print('-'*20, 'Prompt', '-'*20)
-        print(prompt)
-        print('-'*20, 'Output', '-'*20)
-        print(output_str)
+        if local_rank == args.gpu_num - 1:
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            print(f'Inference time: {end-st} s')
+            print(f"First token cost {model.first_token_time:.4f} s and rest tokens cost average {model.rest_cost_mean:.4f} s")
+            print('-'*20, 'Prompt', '-'*20)
+            print(args.prompt)
+            print('-'*20, 'Output', '-'*20)
+            print(output_str)
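The two timing figures printed by the last rank come from attributes that the new `pipeline_parallel_generate` (added in `pipeline_parallel.py` below) records on the model: `first_token_time` for the prefill step and `rest_cost_mean` for the average of the remaining decode steps. A minimal standalone sketch of that bookkeeping, with made-up timings:

```python
import numpy as np

# Made-up per-step latencies in seconds; the real values are measured around each forward pass.
first_token_time = 0.8431                      # time for the first (prefill) step
next_token_time = [0.0512, 0.0498, 0.0505]     # times for the following decode steps

rest_cost_mean = np.mean(next_token_time)
print(f"First token cost {first_token_time:.4f} s and rest tokens cost average {rest_cost_mean:.4f} s")
```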

run_llama2_13b_arc_2_card.sh (new file)

@@ -0,0 +1,30 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
source /opt/intel/oneapi/setvars.sh
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=9090
export FI_PROVIDER=tcp
export USE_XETLA=OFF
export OMP_NUM_THREADS=6
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
fi
export TORCH_LLM_ALLREDUCE=0
NUM_GPUS=2 # number of GPUs to use
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
generate.py --repo-id-or-model-path 'meta-llama/Llama-2-13b-chat-hf' --gpu-num $NUM_GPUS
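The script starts one process per GPU via `torchrun`; each launched process gets its rank and world size from the environment and then calls `init_pipeline_parallel()` (added in `ipex_llm/transformers/pipeline_parallel.py` below), which reuses the exported `MASTER_ADDR`/`MASTER_PORT` (falling back to `127.0.0.1:29500` when unset) and initializes the oneCCL backend. A minimal sketch, assuming a `torchrun --nproc-per-node $NUM_GPUS` launch like the one above, of what each process can verify:

```python
import torch.distributed as dist
from ipex_llm.transformers import init_pipeline_parallel

# torchrun provides RANK / WORLD_SIZE in the environment of every launched process
init_pipeline_parallel()
print(f"rank {dist.get_rank()} of {dist.get_world_size()} initialized with the ccl backend")
```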

ipex_llm/transformers/__init__.py

@@ -22,3 +22,4 @@ from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, \
     AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \
     AutoModelForTokenClassification
 from .modelling_bigdl import *
+from .pipeline_parallel import init_pipeline_parallel
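With `init_pipeline_parallel` exported here, the user-facing flow exercised by `generate.py` above looks roughly as follows. This is a minimal sketch meant to be launched with `torchrun`; the checkpoint id, the prompt, and the `load_in_4bit` low-bit option are illustrative assumptions rather than part of this diff.

```python
# Sketch only: run with `torchrun --standalone --nnodes=1 --nproc-per-node 2 this_script.py`.
import torch
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
from transformers import AutoTokenizer

init_pipeline_parallel()                      # joins the CCL process group set up by torchrun
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",         # example checkpoint from the README
    load_in_4bit=True,                        # assumed low-bit option, not shown in this diff
    pipeline_parallel_stages=world_size,      # must match the launched world size
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf")

input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to(f"xpu:{local_rank}")
output = model.generate(input_ids, max_new_tokens=32)  # dispatched to pipeline_parallel_generate
if local_rank == world_size - 1:              # only the last stage holds the decoded tokens
    print(tokenizer.decode(output.cpu()[0], skip_special_tokens=True))
```

Passing `pipeline_parallel_stages` equal to the launched world size is enforced by the check added in `model.py` below.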

ipex_llm/transformers/model.py

@@ -95,28 +95,6 @@ def save_low_bit(self, *args, **kwargs):
     self.to(origin_device)
-def pipeline_parallel(model, pipeline_parallel_stages):
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-    device_map = {}
-    split_len = len(model_layers) // pipeline_parallel_stages
-    for i in range(pipeline_parallel_stages):
-        device_map.update({key: f'xpu:{i}' for key in
-                           model_layers[split_len * i: split_len * (i + 1)]})
-        if i == pipeline_parallel_stages - 1:
-            device_map.update({key: f'xpu:{i}' for key in
-                               model_layers[split_len * (i + 1):]})
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
-    )
-    return model
 def _load_pre():
     from transformers import GPTJModel
     from ipex_llm.transformers.models.gptj import gptj_model_new_init
@@ -377,8 +355,16 @@ class _BaseAutoModelClass:
             invalidInputError(False,
                               f"Please do not set speculative=True"
                               f" when using pipeline_parallel_stages")
+            invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages,
+                              "Please make sure you've called `init_pipeline_parallel()` "
+                              "and world size is the same as `pipeline_parallel_stages`")
+            from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
             model = pipeline_parallel(model, pipeline_parallel_stages)
+            import types
+            # add pipeline_parallel_generate to pretrained model dynamically
+            model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
+                                                                model)
+            torch.distributed.barrier()
         if speculative:
             from .speculative import speculative_generate, clear_benchmarks,\
                 _crop_past_key_values
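The `types.MethodType` call above binds `pipeline_parallel_generate` to this particular model instance instead of patching its class. A tiny standalone illustration of that binding pattern (toy names, unrelated to IPEX-LLM):

```python
import types

class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    # 'self' is the bound instance, just like the patched model above
    return self.name.upper()

g = Greeter("pipeline")
g.shout = types.MethodType(shout, g)   # attached to this instance only
print(g.shout())                       # PIPELINE
```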

ipex_llm/transformers/pipeline_parallel.py (new file)

@@ -0,0 +1,195 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
#
import torch
from torch import nn
import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList

# patch GenerationMixin.generate
from transformers import GenerationMixin
original_generate = GenerationMixin.generate


class DummyLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
        # python/llm/src/ipex_llm/transformers/models/llama.py#L2076
        self.weight = torch.randn(1,)

    def forward(self, x):
        return x


class Dummy_MLPLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
        # python/llm/src/ipex_llm/transformers/models/llama.py#L119
        self.up_proj = DummyLayer()

    def forward(self, x):
        return x


class Dummy_DecoderLayer(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError
        self.input_layernorm = DummyLayer()
        self.mlp = Dummy_MLPLayer()

    def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs):
        outputs = (hidden_states,)
        if use_cache:
            outputs += (past_key_value,)
        return outputs


def init_pipeline_parallel():
    import oneccl_bindings_for_pytorch
    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
    dist.init_process_group('ccl')


def pipeline_parallel(model, pipeline_parallel_stages):
    slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \
        pipeline_parallel_stages

    local_rank = dist.get_rank()
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)

    for i in range(model.config.num_hidden_layers):
        if i < layer_start or i >= layer_end:
            model._modules['model'].layers[i] = Dummy_DecoderLayer()
        else:
            # align layer_idx and len(past_key_values), otherwise abnormal output
            model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start

    if local_rank != 0:
        model._modules['model'].embed_tokens = DummyLayer()
    if local_rank != pipeline_parallel_stages - 1:
        model._modules['model'].norm = DummyLayer()
        model._modules['lm_head'] = DummyLayer()

    model.pipeline_parallel_stages = pipeline_parallel_stages
    model = model.to(f'xpu:{local_rank}')
    return model


@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
):
    if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
        if generation_config is not None and generation_config.max_new_tokens is not None:
            max_new_tokens = generation_config.max_new_tokens
        else:
            max_new_tokens = kwargs.get("max_new_tokens", None)
        return self.pipeline_parallel_generate(inputs=inputs,
                                               max_new_tokens=max_new_tokens,)

    return original_generate(self,
                             inputs=inputs,
                             generation_config=generation_config,
                             logits_processor=logits_processor,
                             stopping_criteria=stopping_criteria,
                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                             synced_gpus=synced_gpus,
                             assistant_model=assistant_model,
                             streamer=streamer,
                             **kwargs)

GenerationMixin.generate = generate


@torch.no_grad()
def pipeline_parallel_generate(self,
                               inputs: Optional[torch.Tensor] = None,
                               max_new_tokens: int = 32,
                               **kwargs):
    local_rank = dist.get_rank()
    pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
    next_rank = (local_rank + 1) % self.pipeline_parallel_stages

    self.first_token_time = 0
    self.next_token_time = []

    _input_ids = None
    _past_key_values = None
    bs = inputs.shape[0]
    output_ids = inputs.clone()

    step = 0
    while True:
        if step >= max_new_tokens:
            break

        if _input_ids is None:
            _input_ids = inputs

        tic = time.time()
        if local_rank == 0:
            outputs = self(input_ids=_input_ids, inputs_embeds=None,
                           past_key_values=_past_key_values, use_cache=True)
        else:
            inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,),
                                        device=f'xpu:{local_rank}', dtype=torch.float32)
            dist.recv(inputs_embeds, src=pre_rank)
            outputs = self(input_ids=None, inputs_embeds=inputs_embeds,
                           past_key_values=_past_key_values, use_cache=True)

        if local_rank == self.pipeline_parallel_stages - 1:
            logits = outputs.logits
            next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
            dist.broadcast(next_ids, src=local_rank)
        else:
            dist.send(outputs[0], dst=next_rank)
            next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64)
            dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)

        _input_ids = next_ids
        output_ids = torch.cat([output_ids, next_ids], dim=-1)
        _past_key_values = outputs.past_key_values
        toc = time.time()
        if step == 0:
            self.first_token_time = toc - tic
        else:
            self.next_token_time.append(toc - tic)
        step += 1

    if self.device.type == 'xpu':
        torch.xpu.synchronize()
    self.rest_cost_mean = np.mean(self.next_token_time)
    return output_ids
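To make the layer split concrete: `pipeline_parallel()` assigns each rank a contiguous slice of `ceil(num_hidden_layers / pipeline_parallel_stages)` decoder layers, replaces every other layer with a dummy, and keeps `embed_tokens` only on the first rank and `norm`/`lm_head` only on the last. A small standalone sketch of the slicing arithmetic (the 40-layer / 2-stage numbers correspond to Llama-2-13b and are used purely for illustration):

```python
def stage_layer_range(num_hidden_layers: int, stages: int, rank: int):
    # mirrors the slicing in pipeline_parallel() above
    slice_size = (num_hidden_layers + stages - 1) // stages
    layer_start = slice_size * rank
    layer_end = layer_start + min(slice_size, num_hidden_layers - layer_start)
    return layer_start, layer_end

# e.g. a 40-layer model split over 2 stages:
print(stage_layer_range(40, 2, 0))   # (0, 20)  -> rank 0 keeps layers 0..19 plus embed_tokens
print(stage_layer_range(40, 2, 1))   # (20, 40) -> rank 1 keeps layers 20..39 plus norm and lm_head
```

Within its slice, each rank also re-bases `self_attn.layer_idx` so that indices into its shorter `past_key_values` start at zero, which is what the `i - layer_start` assignment above does.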