From 220151e2a109bae54c8c41f6b47d86c4bc4c25da Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Thu, 13 Jun 2024 10:00:23 +0800
Subject: [PATCH] Refactor pipeline parallel multi-stage implementation
(#11286)
---
.../GPU/Pipeline-Parallel-Inference/README.md | 86 ++------
.../Pipeline-Parallel-Inference/generate.py | 51 ++---
.../run_llama2_13b_arc_2_card.sh | 30 +++
.../llm/src/ipex_llm/transformers/__init__.py | 1 +
python/llm/src/ipex_llm/transformers/model.py | 32 +--
.../transformers/pipeline_parallel.py | 195 ++++++++++++++++++
6 files changed, 271 insertions(+), 124 deletions(-)
create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh
create mode 100644 python/llm/src/ipex_llm/transformers/pipeline_parallel.py
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
index 1f51c5f9..c1ffdd96 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -5,90 +5,48 @@ This example demonstrates how to run IPEX-LLM optimized low-bit model vertically
## Requirements
To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
-> [!NOTE]
-> To run IPEX-LLM on multiple Intel GPUs in pipeline parallel fashion, you will need to install **Intel® oneAPI Base Toolkit 2024.1**, which could be done through an offline installer:
-> ```bash
-> wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/fdc7a2bc-b7a8-47eb-8876-de6201297144/l_BaseKit_p_2024.1.0.596_offline.sh
->
-> sudo sh ./l_BaseKit_p_2024.1.0.596_offline.sh
-> ```
+## Verified Models
+- [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+- [Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
+- [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
## Example: Run pipeline parallel inference on multiple GPUs
+### 0. Prerequisites
+
+Please visit [Install IPEX-LLM on Linux with Intel GPU](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html), and follow the [Install Intel GPU Driver](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-intel-gpu-driver) and [Install oneAPI](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/install_linux_gpu.html#install-oneapi) sections to install the GPU driver and Intel® oneAPI Base Toolkit 2024.0.
+
### 1. Installation
```bash
conda create -n llm python=3.11
conda activate llm
-
+# The below command will install intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install torch==2.1.0.post2 torchvision==0.16.0.post2 torchaudio==2.1.0.post2 intel-extension-for-pytorch==2.1.30+xpu oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```
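+
+Optionally, you can run a quick sanity check after installation. The snippet below is a minimal sketch (assuming the packages above installed successfully) that verifies PyTorch can see your Intel GPUs:
+
+```python
+import torch
+import intel_extension_for_pytorch as ipex  # side-effect import: registers the XPU backend with PyTorch
+import oneccl_bindings_for_pytorch  # side-effect import: registers the 'ccl' distributed backend used by pipeline parallel
+
+# pipeline parallel inference needs at least two visible Intel GPUs
+print(f"XPU devices detected: {torch.xpu.device_count()}")
+```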
-### 2. Configures OneAPI environment variables
+### 2. Run pipeline parallel inference on multiple GPUs
+
+For optimal performance, it is recommended to set several environment variables. We provide example usage as follows:
+
+- Run Llama-2-13b-chat-hf on two Intel Arc A770
```bash
-source /opt/intel/oneapi/setvars.sh
+bash run_llama2_13b_arc_2_card.sh
```
-> [!NOTE]
-> Please make sure you configure the environment variables for **Intel® oneAPI Base Toolkit's version == 2024.1.**.
-
-### 3 Runtime Configurations
-
-For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
-
-
-
-For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series
-
-```bash
-export USE_XETLA=OFF
-export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
-export SYCL_CACHE_PERSISTENT=1
-```
-
-
-
-
-
-For Intel Data Center GPU Max Series
-
-```bash
-export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
-export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
-export SYCL_CACHE_PERSISTENT=1
-export ENABLE_SDP_FUSION=1
-```
-> [!NOTE]
-> Please note that `libtcmalloc.so` can be installed by `conda install -c conda-forge -y gperftools=2.10`.
-
-
-### 4. Running examples
-```
-python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT --gpu-num GPU_NUM
-```
-
-Arguments info:
-- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'meta-llama/Llama-2-7b-chat-hf'`.
-- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is AI?'`.
-- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
-- `--gpu-num GPU_NUM`: argument defining the number of GPU to use. It is default to be `2`.
+> **Note**: You can change `NUM_GPUS` to the number of GPUs available on your machine.
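+
+The script simply launches `generate.py` in this folder with `torchrun`. If you prefer to call the API directly, the following minimal sketch (model path, prompt, and load options are illustrative) shows the new `init_pipeline_parallel` entry point together with the `pipeline_parallel_stages` argument of `from_pretrained`:
+
+```python
+import torch
+from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
+from transformers import AutoTokenizer
+
+init_pipeline_parallel()  # sets up torch.distributed with the 'ccl' backend
+
+model_path = "meta-llama/Llama-2-13b-chat-hf"
+model = AutoModelForCausalLM.from_pretrained(model_path,
+                                             load_in_4bit=True,
+                                             optimize_model=True,
+                                             use_cache=True,
+                                             pipeline_parallel_stages=2)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+local_rank = torch.distributed.get_rank()
+with torch.inference_mode():
+    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to(f'xpu:{local_rank}')
+    output = model.generate(input_ids, max_new_tokens=32)
+    if local_rank == 1:  # print from the last pipeline stage only, as generate.py does
+        print(tokenizer.decode(output[0].cpu(), skip_special_tokens=True))
+```
+
+Note that such a script must still be launched with `torchrun --nproc-per-node <NUM_GPUS>` (as `run_llama2_13b_arc_2_card.sh` does), so that each pipeline stage runs in its own process.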
#### Sample Output
-##### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
+##### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
```log
Inference time: xxxx s
+First token cost xxxx s and rest tokens cost average xxxx s
-------------------- Prompt --------------------
-[INST] <<SYS>>
-
-<</SYS>>
-
-What is AI? [/INST]
+Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
-------------------- Output --------------------
-[INST] <<SYS>>
+Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. She was always asking her parents to take her on trips, but they were always too busy or too tired.
-<</SYS>>
-
-What is AI? [/INST] Artificial intelligence (AI) is the broader field of research and development aimed at creating machines that can perform tasks that typically require human intelligence,
+One day, the little girl
```
\ No newline at end of file
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
index ae3cedb1..5104c701 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -19,34 +19,18 @@ import torch
import time
import argparse
-from ipex_llm.transformers import AutoModelForCausalLM
+from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
from transformers import AutoTokenizer
-# you could tune the prompt based on your own model,
-# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
-DEFAULT_SYSTEM_PROMPT = """\
-"""
-
-def get_prompt(message: str, chat_history: list[tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} [INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
+init_pipeline_parallel()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-13b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
-    parser.add_argument('--prompt', type=str, default="What is AI?",
+    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')
@@ -66,35 +50,28 @@ if __name__ == '__main__':
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    local_rank = torch.distributed.get_rank()

    # Generate predicted tokens
    with torch.inference_mode():
-        prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
-        input_ids = tokenizer.encode(prompt, return_tensors="pt").to('xpu:0')
+        input_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(f'xpu:{local_rank}')
        # ipex_llm model needs a warmup, then inference time can be accurate
        output = model.generate(input_ids,
-                                do_sample=False,
-                                max_new_tokens=args.n_predict)
-        output = model.generate(input_ids,
-                                do_sample=False,
                                max_new_tokens=args.n_predict)

        # start inference
        st = time.time()
-        # if your selected model is capable of utilizing previous key/value attentions
-        # to enhance decoding speed, but has `"use_cache": false` in its model config,
-        # it is important to set `use_cache=True` explicitly in the `generate` function
-        # to obtain optimal performance with IPEX-LLM INT4 optimizations
        output = model.generate(input_ids,
-                                do_sample=False,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()
        output = output.cpu()
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        print(f'Inference time: {end-st} s')
-        print('-'*20, 'Prompt', '-'*20)
-        print(prompt)
-        print('-'*20, 'Output', '-'*20)
-        print(output_str)
+        if local_rank == args.gpu_num - 1:
+            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+            print(f'Inference time: {end-st} s')
+            print(f"First token cost {model.first_token_time:.4f} s and rest tokens cost average {model.rest_cost_mean:.4f} s")
+            print('-'*20, 'Prompt', '-'*20)
+            print(args.prompt)
+            print('-'*20, 'Output', '-'*20)
+            print(output_str)
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh
new file mode 100644
index 00000000..5924aada
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_llama2_13b_arc_2_card.sh
@@ -0,0 +1,30 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+source /opt/intel/oneapi/setvars.sh
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=9090
+export FI_PROVIDER=tcp
+export USE_XETLA=OFF
+export OMP_NUM_THREADS=6
+KERNEL_VERSION=$(uname -r)
+# enable immediate command lists unless running a 6.5 kernel
+if [[ $KERNEL_VERSION != *"6.5"* ]]; then
+    export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+fi
+export TORCH_LLM_ALLREDUCE=0
+
+NUM_GPUS=2 # number of GPUs to use
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
+    generate.py --repo-id-or-model-path 'meta-llama/Llama-2-13b-chat-hf' --gpu-num $NUM_GPUS
diff --git a/python/llm/src/ipex_llm/transformers/__init__.py b/python/llm/src/ipex_llm/transformers/__init__.py
index 02d51f2f..e95e7704 100644
--- a/python/llm/src/ipex_llm/transformers/__init__.py
+++ b/python/llm/src/ipex_llm/transformers/__init__.py
@@ -22,3 +22,4 @@ from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, \
    AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \
    AutoModelForTokenClassification
from .modelling_bigdl import *
+from .pipeline_parallel import init_pipeline_parallel
diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 70c2f0d9..8aa002ee 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -95,28 +95,6 @@ def save_low_bit(self, *args, **kwargs):
self.to(origin_device)
-def pipeline_parallel(model, pipeline_parallel_stages):
-    model_layers = ['model.embed_tokens']
-    for i in range(model.config.num_hidden_layers):
-        model_layers.append(f'model.layers.{i}')
-    model_layers = model_layers + ['model.norm', 'lm_head']
-
-    device_map = {}
-    split_len = len(model_layers) // pipeline_parallel_stages
-    for i in range(pipeline_parallel_stages):
-        device_map.update({key: f'xpu:{i}' for key in
-                           model_layers[split_len * i: split_len * (i + 1)]})
-        if i == pipeline_parallel_stages - 1:
-            device_map.update({key: f'xpu:{i}' for key in
-                               model_layers[split_len * (i + 1):]})
-
-    from accelerate import dispatch_model
-    model = dispatch_model(
-        model, device_map=device_map, skip_keys=["past_key_value", "past_key_values"],
-    )
-    return model
-
-
def _load_pre():
from transformers import GPTJModel
from ipex_llm.transformers.models.gptj import gptj_model_new_init
@@ -377,8 +355,16 @@ class _BaseAutoModelClass:
                invalidInputError(False,
                                  f"Please do not set speculative=True"
                                  f" when using pipeline_parallel_stages")
+            invalidInputError(torch.distributed.get_world_size() == pipeline_parallel_stages,
+                              "Please make sure you've called `init_pipeline_parallel()` "
+                              "and world size is the same as `pipeline_parallel_stages`")
+            from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
            model = pipeline_parallel(model, pipeline_parallel_stages)
-
+            import types
+            # add pipeline_parallel_generate to pretrained model dynamically
+            model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
+                                                                model)
+            torch.distributed.barrier()

        if speculative:
            from .speculative import speculative_generate, clear_benchmarks,\
                _crop_past_key_values
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
new file mode 100644
index 00000000..d750cc1b
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -0,0 +1,195 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file are adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
+#
+
+import torch
+from torch import nn
+import torch.distributed as dist
+import os
+import time
+import numpy as np
+from typing import Callable, List, Optional
+from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+
+# patch GenerationMixin.generate
+from transformers import GenerationMixin
+original_generate = GenerationMixin.generate
+
+
+class DummyLayer(nn.Module):
+    def __init__(self, *args):
+        super().__init__()
+        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
+        # python/llm/src/ipex_llm/transformers/models/llama.py#L2076
+        self.weight = torch.randn(1,)
+
+    def forward(self, x):
+        return x
+
+
+class Dummy_MLPLayer(nn.Module):
+    def __init__(self, *args):
+        super().__init__()
+        # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
+        # python/llm/src/ipex_llm/transformers/models/llama.py#L119
+        self.up_proj = DummyLayer()
+
+    def forward(self, x):
+        return x
+
+
+class Dummy_DecoderLayer(nn.Module):
+    def __init__(self, *args):
+        super().__init__()
+        # to avoid AttributeError
+        self.input_layernorm = DummyLayer()
+        self.mlp = Dummy_MLPLayer()
+
+    def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs):
+        outputs = (hidden_states,)
+        if use_cache:
+            outputs += (past_key_value,)
+        return outputs
+
+
+def init_pipeline_parallel():
+    # importing oneccl_bindings_for_pytorch registers the 'ccl' backend;
+    # RANK and WORLD_SIZE are expected to be set by the launcher (e.g. torchrun)
+    import oneccl_bindings_for_pytorch
+    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
+    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
+    dist.init_process_group('ccl')
+
+
+def pipeline_parallel(model, pipeline_parallel_stages):
+    # number of decoder layers assigned to each stage (the last stage may get fewer)
+    slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \
+        pipeline_parallel_stages
+
+    local_rank = dist.get_rank()
+    layer_start = slice_size * local_rank
+    layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)
+
+    for i in range(model.config.num_hidden_layers):
+        if i < layer_start or i >= layer_end:
+            model._modules['model'].layers[i] = Dummy_DecoderLayer()
+        else:
+            # align layer_idx and len(past_key_values), otherwise abnormal output
+            model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start
+
+    if local_rank != 0:
+        model._modules['model'].embed_tokens = DummyLayer()
+    if local_rank != pipeline_parallel_stages - 1:
+        model._modules['model'].norm = DummyLayer()
+        model._modules['lm_head'] = DummyLayer()
+
+    model.pipeline_parallel_stages = pipeline_parallel_stages
+    model = model.to(f'xpu:{local_rank}')
+    return model
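+
+# Example of the resulting split: Llama-2-13b has num_hidden_layers = 40, so with
+# pipeline_parallel_stages = 2 the slice size is 20; rank 0 keeps layers 0-19 together
+# with embed_tokens, while rank 1 keeps layers 20-39, the final norm and lm_head.
+# Every layer outside a rank's slice is replaced by a Dummy_DecoderLayer placeholder.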
+
+
+@torch.no_grad()
+def generate(
+    self,
+    inputs: Optional[torch.Tensor] = None,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None,
+    synced_gpus: Optional[bool] = None,
+    assistant_model: Optional["PreTrainedModel"] = None,
+    streamer: Optional["BaseStreamer"] = None,
+    **kwargs,
+):
+    # route to pipeline_parallel_generate when the model was split across stages,
+    # otherwise fall back to the original transformers implementation
+    if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
+        if generation_config is not None and generation_config.max_new_tokens is not None:
+            max_new_tokens = generation_config.max_new_tokens
+        else:
+            max_new_tokens = kwargs.get("max_new_tokens", None)
+        return self.pipeline_parallel_generate(inputs=inputs,
+                                               max_new_tokens=max_new_tokens,)
+
+    return original_generate(self,
+                             inputs=inputs,
+                             generation_config=generation_config,
+                             logits_processor=logits_processor,
+                             stopping_criteria=stopping_criteria,
+                             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                             synced_gpus=synced_gpus,
+                             assistant_model=assistant_model,
+                             streamer=streamer,
+                             **kwargs)
+
+GenerationMixin.generate = generate
+
+
+@torch.no_grad()
+def pipeline_parallel_generate(self,
+                               inputs: Optional[torch.Tensor] = None,
+                               max_new_tokens: int = 32,
+                               **kwargs):
+    local_rank = dist.get_rank()
+    pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
+    next_rank = (local_rank + 1) % self.pipeline_parallel_stages
+
+    self.first_token_time = 0
+    self.next_token_time = []
+
+    _input_ids = None
+    _past_key_values = None
+    bs = inputs.shape[0]
+    output_ids = inputs.clone()
+
+    step = 0
+    while True:
+        if step >= max_new_tokens:
+            break
+
+        if _input_ids is None:
+            _input_ids = inputs
+
+        tic = time.time()
+        if local_rank == 0:
+            # the first stage consumes the input ids directly
+            outputs = self(input_ids=_input_ids, inputs_embeds=None,
+                           past_key_values=_past_key_values, use_cache=True)
+        else:
+            # later stages receive hidden states from the previous stage
+            inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,),
+                                        device=f'xpu:{local_rank}', dtype=torch.float32)
+            dist.recv(inputs_embeds, src=pre_rank)
+            outputs = self(input_ids=None, inputs_embeds=inputs_embeds,
+                           past_key_values=_past_key_values, use_cache=True)
+
+        if local_rank == self.pipeline_parallel_stages - 1:
+            # the last stage holds lm_head: pick the next token greedily and broadcast it
+            logits = outputs.logits
+            next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
+            dist.broadcast(next_ids, src=local_rank)
+        else:
+            # non-final stages send their hidden states onward and wait for the broadcast token
+            dist.send(outputs[0], dst=next_rank)
+            next_ids = torch.empty((bs, 1), device=f'xpu:{local_rank}', dtype=torch.int64)
+            dist.broadcast(next_ids, src=self.pipeline_parallel_stages - 1)
+
+        _input_ids = next_ids
+        output_ids = torch.cat([output_ids, next_ids], dim=-1)
+        _past_key_values = outputs.past_key_values
+        toc = time.time()
+        if step == 0:
+            self.first_token_time = toc - tic
+        else:
+            self.next_token_time.append(toc - tic)
+        step += 1
+
+    if self.device.type == 'xpu':
+        torch.xpu.synchronize()
+    self.rest_cost_mean = np.mean(self.next_token_time)
+    return output_ids