Add more qwen1.5 and qwen2 support for pipeline parallel inference (#11423)

This commit is contained in:
binbin Deng 2024-06-25 15:49:32 +08:00 committed by GitHub
parent aacc1fd8c0
commit e473b8d946
5 changed files with 74 additions and 2 deletions


@@ -9,9 +9,12 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
- [meta-llama/Llama-2-7b-chat-hf](./run_llama_arc_2_card.sh)
- [meta-llama/Llama-2-13b-chat-hf](./run_llama_arc_2_card.sh)
- [meta-llama/Meta-Llama-3-8B-Instruct](./run_llama_arc_2_card.sh)
- [Qwen/Qwen2-7B-Instruct](./run_qwen2_arc_2_card.sh)
- [Qwen/Qwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-MoE-A2.7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
@@ -67,10 +70,26 @@ bash run_llama_arc_2_card.sh
</details>
<details>
<summary> Show Qwen2 example </summary>
#### Run Qwen2-7B-Instruct on two Intel Arc A770
You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for Qwen2 to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
```bash
pip install transformers==4.37.0
bash run_qwen2_arc_2_card.sh
```
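If the checkpoint is already on disk, the torchrun line inside `run_qwen2_arc_2_card.sh` can point at a local folder instead of the repo id. A minimal sketch of that line, using a placeholder path:
```bash
# Illustrative only: '/path/to/Qwen2-7B-Instruct' is a placeholder for a local
# checkpoint folder; keep --gpu-num in sync with the NUM_GPUS value set in the script.
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
  generate.py --repo-id-or-model-path '/path/to/Qwen2-7B-Instruct' --gpu-num $NUM_GPUS
```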
</details>
</details>
<details>
<summary> Show Qwen1.5 example </summary>
#### Run Qwen1.5-7B-Chat / Qwen1.5-14B-Chat / Qwen1.5-32B-Chat / CodeQwen1.5-7B-Chat on two Intel Arc A770
You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for Qwen1.5 to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
@@ -79,6 +98,15 @@ pip install transformers==4.37.0
bash run_qwen1.5_arc_2_card.sh
```
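For instance, on a machine with four Arc GPUs the variables in `run_qwen1.5_arc_2_card.sh` could be adjusted as below (a sketch only; the 14B repo id is just one of the Qwen1.5 options listed above):
```bash
# Hypothetical 4-GPU setup: NUM_GPUS drives both the torchrun worker count and --gpu-num
NUM_GPUS=4
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
  generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-14B-Chat' --gpu-num $NUM_GPUS
```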
#### Run Qwen1.5-MoE-A2.7B-Chat on two Intel Arc A770
You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for Qwen1.5-MoE to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
```bash
pip install transformers==4.40.0 trl==0.8.1
bash run_qwen1.5_arc_2_card.sh
```
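The MoE command is added to `run_qwen1.5_arc_2_card.sh` in this commit but left commented out (see the script diff further down); switching to it means uncommenting that block and commenting out whichever invocation is currently active. Uncommented, it reads:
```bash
# Qwen1.5-MoE invocation from run_qwen1.5_arc_2_card.sh, shown uncommented
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
  generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-MoE-A2.7B-Chat' --gpu-num $NUM_GPUS
```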
</details>
</details>


@@ -38,3 +38,11 @@ CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $N
# # To run Qwen1.5-32B-Chat
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-32B-Chat' --gpu-num $NUM_GPUS
# # To run Qwen1.5-MoE-A2.7B-Chat
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-MoE-A2.7B-Chat' --gpu-num $NUM_GPUS
# # To run CodeQwen1.5-7B-Chat
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'Qwen/CodeQwen1.5-7B-Chat' --gpu-num $NUM_GPUS


@@ -0,0 +1,32 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
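# Pipeline-parallel generation for Qwen2-7B-Instruct: source the oneAPI environment,
# set the SYCL/oneCCL related variables, then launch generate.py across NUM_GPUS GPUs
# via torchrun.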
source /opt/intel/oneapi/setvars.sh
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=9090
export FI_PROVIDER=tcp
export USE_XETLA=OFF
export OMP_NUM_THREADS=6
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
fi
export TORCH_LLM_ALLREDUCE=0
NUM_GPUS=2 # number of GPUs to use
# To run Qwen2-7B-Instruct
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
generate.py --repo-id-or-model-path 'Qwen/Qwen2-7B-Instruct' --gpu-num $NUM_GPUS


@@ -72,7 +72,8 @@ def qwen2moe_model_forward(
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
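# input_ids can be None when the caller passes inputs_embeds instead, so fall back to
# inputs_embeds before deciding which kv-cache path to use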
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)


@@ -24,6 +24,7 @@ import os
import time
import numpy as np
from typing import Callable, List, Optional
from types import SimpleNamespace
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.utils.common import invalidInputError
import logging
@@ -52,6 +53,8 @@ class Dummy_MLPLayer(nn.Module):
# python/llm/src/ipex_llm/transformers/models/llama.py#L119
self.up_proj = DummyLayer()
self.down_proj = DummyLayer()
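# the Qwen2-MoE forward above probes layers[0].mlp.shared_expert.up_proj when choosing
# the kv-cache path, so the dummy MLP must expose the same attribute chain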
self.shared_expert = SimpleNamespace()
self.shared_expert.up_proj = DummyLayer()
def forward(self, x):
return x