Support pipeline parallel for qwen-vl (#11503)

binbin Deng 2024-07-04 18:03:57 +08:00 committed by GitHub
parent 57b8adb189
commit 60de428b37
6 changed files with 370 additions and 11 deletions

@@ -14,6 +14,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
- [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-MoE-A2.7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen-VL-Chat](./run_qwen_vl_arc_2_card.sh)
- [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [THUDM/glm-4-9b-chat](./run_chatglm_arc_2_card.sh)
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
@@ -114,6 +115,22 @@ bash run_qwen1.5_arc_2_card.sh
</details>
<details>
<summary> Show Qwen-VL example </summary>
#### Run Qwen-VL-Chat on two Intel Arc A770
You can specify `--repo-id-or-model-path` in the test script as either the Hugging Face repo id of the Qwen-VL model to download or the path to a local Hugging Face checkpoint folder. You can also change `NUM_GPUS` to the number of GPUs available on your machine.
```bash
pip install transformers==4.32.0 tiktoken einops transformers_stream_generator==0.0.4 scipy torchvision pillow tensorboard matplotlib
bash run_qwen_vl_arc_2_card.sh
```
</details>
</details>
<details>
<summary> Show chatglm example </summary>
@@ -250,3 +267,11 @@ Once upon a time, there existed a little girl who liked to have adventures. She
One day, the little girl
```
#### [Qwen/Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
```log
-------------------- Input --------------------
Message: [{'image': 'http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg'}, {'text': '这是什么?'}]
-------------------- Output --------------------
这是一张图片,展现了一个穿着粉色条纹连衣裙的小女孩,她正拿着一只穿粉色裙子的白色玩具小熊。
```

@@ -0,0 +1,76 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import os
import torch
import time
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
init_pipeline_parallel()
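# initialize torch.distributed so that each torchrun process can act as one pipeline-parallel stage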
torch.manual_seed(1234)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict tokens using the `chat()` API for a large vision-language model')
parser.add_argument('--repo-id-or-model-path', type=str, default="Qwen/Qwen-VL-Chat",
help='The huggingface repo id for the Qwen-VL-Chat model to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--image-url-or-path', type=str,
default="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg",
help='The URL or path to the image to infer')
parser.add_argument('--prompt', type=str, default="这是什么?",
help='Prompt to infer')
parser.add_argument('--n-predict', type=int, default=32,
help='Max tokens to predict')
    parser.add_argument('--low-bit', type=str, default='sym_int4', help='The low-bit quantization type the model will be converted to.')
    parser.add_argument('--gpu-num', type=int, default=2, help='Number of GPUs to use')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
image_path = args.image_url_or_path
# Load model
# For successful IPEX-LLM optimization on Qwen-VL-Chat, skip the 'c_fc' and 'out_proj' modules during optimization
    # For Windows users running LLMs on Intel iGPUs, we recommend setting `cpu_embedding=True` in the from_pretrained function.
    # This allows the memory-intensive embedding layer to run on the CPU instead of the iGPU.
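    # `pipeline_parallel_stages=args.gpu_num` splits the decoder layers across that many XPU devices,
    # so each torchrun process (rank) only loads and runs its own pipeline stage.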
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit=args.low_bit,
optimize_model=True,
trust_remote_code=True,
use_cache=True,
torch_dtype=torch.float32,
modules_to_not_convert=['c_fc', 'out_proj'],
pipeline_parallel_stages=args.gpu_num)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
local_rank = torch.distributed.get_rank()
all_input = [{'image': args.image_url_or_path}, {'text': args.prompt}]
input_list = [_input for _input in all_input if list(_input.values())[0] != '']
query = tokenizer.from_list_format(input_list)
with torch.inference_mode():
response, _ = model.chat(tokenizer, query=query, history=[])
torch.xpu.synchronize()
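    # in pipeline parallelism the last stage produces the final output, so only the last rank prints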
if local_rank == args.gpu_num - 1:
print('-'*20, 'Input', '-'*20)
print(f'Message: {all_input}')
print('-'*20, 'Output', '-'*20)
print(response)

@@ -0,0 +1,32 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
source /opt/intel/oneapi/setvars.sh
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=9090
export FI_PROVIDER=tcp
export USE_XETLA=OFF
export OMP_NUM_THREADS=6
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
fi
export TORCH_LLM_ALLREDUCE=0
NUM_GPUS=2 # number of GPUs to use
# To run Qwen-VL-Chat
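# torchrun starts $NUM_GPUS processes; each process hosts one pipeline-parallel stage of the model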
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
chat.py --repo-id-or-model-path 'Qwen/Qwen-VL-Chat' --gpu-num $NUM_GPUS --low-bit 'sym_int4'

@@ -1269,10 +1269,14 @@ def _optimize_post(model, lightweight_bmm=False):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.qwen_vl import qwen_attention_forward_vl
from ipex_llm.transformers.models.qwen_vl import qwen_vl_model_forward
convert_forward(model,
module.QWenAttention,
qwen_attention_forward_vl
)
convert_forward(model,
module.QWenModel,
qwen_vl_model_forward)
else:
# for Qwen-7B and Qwen-14B
modeling_module_name = model.__class__.__module__
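For readers unfamiliar with `convert_forward`, the sketch below illustrates the general idea under the assumption that it rebinds the replacement function as the `forward` of every matching module instance (the real ipex_llm helper may differ in detail); a single call is then enough to route every `QWenModel` forward through `qwen_vl_model_forward`. `Toy` and `doubled_forward` are made up for illustration.

```python
# A minimal sketch of a convert_forward-style helper, assuming it works by
# monkey-patching module instances; not the actual ipex_llm implementation.
import types
import torch
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    for module in model.modules():
        if isinstance(module, target_cls):
            # bind new_forward so that `self` refers to this module instance
            module.forward = types.MethodType(new_forward, module)

class Toy(nn.Module):
    def forward(self, x):
        return x

def doubled_forward(self, x):
    return x * 2

if __name__ == "__main__":
    m = nn.Sequential(Toy(), Toy())
    convert_forward_sketch(m, Toy, doubled_forward)
    print(m(torch.tensor([1.0])))  # tensor([4.]) -- both Toy forwards were replaced
```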

@@ -33,7 +33,8 @@ from transformers.utils import logging
from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import rotate_half
from ipex_llm.transformers.models.utils import use_sdp
from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.utils.common import invalidInputError
import os
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@@ -243,3 +244,209 @@ def qwen_vl_vision_transformer_forward(self, x: torch.Tensor):
x = x @ self.proj
return x
def qwen_vl_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# bigdl-llm change starts
input = input_ids if input_ids is not None else inputs_embeds
# bigdl-llm change ends
if past_key_values is None and torch.any(input == self.config.visual['image_start_id']):
bos_pos = torch.where(input == self.config.visual['image_start_id'])
eos_pos = torch.where(input == self.config.visual['image_start_id'] + 1)
invalidInputError((bos_pos[0] == eos_pos[0]).all(),
'bos_pos[0] should be same as eos_pos[0]')
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input[i][a + 1: b - 1].tolist()
image = image[: image.index(self.config.visual['image_start_id'] + 2)]
images.append(bytes(image).decode('utf-8'))
images = self.visual.encode(images)
invalidInputError(images.shape[0] == len(images),
'images.shape[0] should be same as len(images)')
fake_images = None
elif self.training:
fake_images = torch.zeros(1, 3, 224, 224).to(
dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
images = self.visual(fake_images)
else:
fake_images = None
images = None
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
invalidInputError(False,
"You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if batch_size <= 0:
invalidInputError(False, "batch_size has to be defined and > 0")
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_length
)
hidden_states = inputs_embeds
kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim
kv_seq_len += past_key_values[0][0].shape[1]
if (
self.use_dynamic_ntk
and kv_seq_len == hidden_states.size()[1]
and not self.training
):
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
ntk_alpha = max(ntk_alpha, 1)
else:
ntk_alpha = self.rotary_emb._ntk_alpha_cached
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
for idx in range(len(rotary_pos_emb)):
rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
hidden_states = self.drop(hidden_states).clone()
if fake_images is not None:
hidden_states = hidden_states + images.mean()*0
elif images is not None:
for idx, (i, a, b) in enumerate(img_pos):
hidden_states[i][a + 1: b] = images[idx]
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
rotary_pos_emb,
self.registered_causal_mask,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
rotary_pos_emb=rotary_pos_emb,
registered_causal_mask=self.registered_causal_mask,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
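The core of `qwen_vl_model_forward` above follows upstream Qwen-VL: on the prefill pass it locates the spans delimited by the image start/end marker tokens, decodes the image reference embedded between them, runs the visual encoder, and overwrites the hidden states of each span with the resulting image features. A stripped-down sketch with toy tensors and a made-up marker id (not the real `config.visual['image_start_id']`) shows just the splicing step:

```python
# A minimal, self-contained sketch of the image-feature splicing done in
# qwen_vl_model_forward; sizes and the marker id are invented for illustration.
import torch

IMAGE_START_ID = 1000            # stands in for config.visual['image_start_id']
IMAGE_END_ID = IMAGE_START_ID + 1

def splice_image_features(input_ids, hidden_states, image_features):
    """Replace hidden states between start/end markers with image features."""
    bos_pos = torch.where(input_ids == IMAGE_START_ID)
    eos_pos = torch.where(input_ids == IMAGE_END_ID)
    # (batch index, start column, end column) for every image span
    img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
    for idx, (i, a, b) in enumerate(img_pos):
        hidden_states[i][a + 1:b] = image_features[idx]
    return hidden_states

if __name__ == "__main__":
    seq = torch.full((1, 12), 7)       # ordinary text tokens
    seq[0, 3] = IMAGE_START_ID         # <img> marker
    seq[0, 8] = IMAGE_END_ID           # </img> marker
    hidden = torch.zeros(1, 12, 4)
    feats = torch.ones(1, 4, 4)        # 4 image tokens, hidden size 4
    out = splice_image_features(seq, hidden, feats)
    print(out[0, 4:8])                 # the spliced span now holds the image features
```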

@@ -123,9 +123,20 @@ def pipeline_parallel(model, pipeline_parallel_stages):
layer_start = slice_size * local_rank
layer_end = layer_start + min(slice_size, num_layers - layer_start)
if model.config.architectures is not None \
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
# for chatglm3-6b
if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
# for Qwen-VL-Chat
for i in range(num_layers):
if i < layer_start or i >= layer_end:
model._modules['transformer'].h[i] = Dummy_DecoderLayer()
if local_rank != 0:
model._modules['transformer'].wte = DummyLayer()
model._modules['transformer'].drop = DummyLayer()
if local_rank != pipeline_parallel_stages - 1:
model._modules['transformer'].ln_f = DummyLayer()
model._modules['ln_f'] = DummyLayer()
model._modules['lm_head'] = DummyLayer()
elif model.config.model_type == "chatglm":
# for chatglm3-6b, glm-4-9b-chat
for i in range(num_layers):
if i < layer_start or i >= layer_end:
model._modules['transformer'].encoder.layers[i] = Dummy_GLMBlock()
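The sketch below shows the partitioning idea used in the Qwen-VL branch above: each rank keeps a contiguous slice of `transformer.h` and replaces the rest with pass-through dummies, while `wte`/`drop` are dummied out on every rank except the first and `ln_f`/`lm_head` on every rank except the last. `DummyLayer` here and the ceil-division `slice_size` are assumptions for illustration (the `slice_size` line sits above this hunk), not the ipex_llm implementation.

```python
# A minimal sketch of pipeline-parallel layer partitioning; toy modules only.
import torch.nn as nn

class DummyLayer(nn.Module):
    def forward(self, x, *args, **kwargs):
        return x  # pass hidden states through unchanged

def partition_layers(layers: nn.ModuleList, num_stages: int, local_rank: int):
    num_layers = len(layers)
    slice_size = (num_layers + num_stages - 1) // num_stages  # ceil division (assumed)
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, num_layers - layer_start)
    for i in range(num_layers):
        if i < layer_start or i >= layer_end:
            layers[i] = DummyLayer()  # this rank does not own layer i
    return layer_start, layer_end

if __name__ == "__main__":
    blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(32))  # stand-in for transformer.h
    start, end = partition_layers(blocks, num_stages=2, local_rank=1)
    print(start, end)  # 16 32 -- rank 1 of 2 keeps the second half of the layers
```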
@@ -296,13 +307,17 @@ def pipeline_parallel_generate(self,
_past_key_values = past_key_values_placeholder
else:
_past_key_values = outputs.past_key_values
elif self.config.model_type in ["baichuan", "chatglm"] and local_rank != 0:
# for baichuan2 and chatglm3
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (outputs.past_key_values)[layer_start:]
_past_key_values = past_key_values_placeholder
elif self.config.model_type in ["baichuan", "chatglm"] or \
(self.config.model_type == "qwen" and hasattr(self.config, "visual")):
# for baichuan2, chatglm3, Qwen-VL-Chat
if local_rank != 0:
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (outputs.past_key_values)[layer_start:]
_past_key_values = past_key_values_placeholder
else:
_past_key_values = outputs.past_key_values
else:
_past_key_values = outputs.past_key_values
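The new branch above pads the cache on non-first ranks so that `past_key_values` stays indexed by global layer number even though each rank only computed caches for its own layers; the first stage keeps `outputs.past_key_values` untouched. A synthetic sketch of that padding (not the ipex_llm API, tensor shapes invented):

```python
# A minimal sketch of the past_key_values placeholder padding used in
# pipeline_parallel_generate for baichuan2, chatglm3 and Qwen-VL-Chat.
import torch

def pad_past_key_values(past_key_values, layer_start):
    # entries before layer_start belong to earlier pipeline stages; fill them
    # with empty placeholders so indices still match global layer numbers
    value_placeholder = torch.empty_like(past_key_values[-1][0])
    return tuple(
        (value_placeholder, value_placeholder) for _ in range(layer_start)
    ) + tuple(past_key_values[layer_start:])

if __name__ == "__main__":
    # pretend this rank owns layers 2..3 of a 4-layer model (bs=1, seq=5, heads=2, dim=8)
    kv = tuple((torch.randn(1, 5, 2, 8), torch.randn(1, 5, 2, 8)) for _ in range(4))
    padded = pad_past_key_values(kv, layer_start=2)
    print(len(padded), padded[0][0].shape)  # 4 torch.Size([1, 5, 2, 8])
```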