From 2b8ad8731ee9815cadaddd89e8818fd8510ef07d Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Thu, 11 Jul 2024 16:06:06 +0800
Subject: [PATCH] Support pipeline parallel for glm-4v (#11545)
---
.../GPU/Pipeline-Parallel-Inference/README.md | 15 ++++
.../glm_4v_generate.py | 87 +++++++++++++++++++
.../run_glm_4v_arc_2_card.sh | 31 +++++++
.../ipex_llm/transformers/models/chatglm4v.py | 11 +--
.../transformers/pipeline_parallel.py | 53 ++++++++---
5 files changed, 179 insertions(+), 18 deletions(-)
create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py
create mode 100644 python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
index cb9df2d0..c350be36 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -17,6 +17,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
- [Qwen/Qwen-VL-Chat](./run_qwen_vl_arc_2_card.sh)
- [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [THUDM/glm-4-9b-chat](./run_chatglm_arc_2_card.sh)
+- [THUDM/glm-4v-9b](./run_glm_4v_arc_2_card.sh)
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
@@ -145,6 +146,20 @@ bash run_chatglm_arc_2_card.sh
+
+ Show glm-4v example
+
+#### Run glm-4v-9b on two Intel Arc A770
+
+You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for glm-4v-9b to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
+
+```bash
+pip install transformers==4.37.0 tiktoken
+bash run_glm_4v_arc_2_card.sh
+```
+
+
+
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py b/python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py
new file mode 100644
index 00000000..f788a362
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/glm_4v_generate.py
@@ -0,0 +1,87 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import time
+import torch
+import argparse
+import requests
+
+from PIL import Image
+from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
+from transformers import AutoTokenizer
+
+init_pipeline_parallel()
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for THUDM/glm-4v-9b model')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="THUDM/glm-4v-9b",
+ help='The huggingface repo id for the THUDM/glm-4v-9b model to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--image-url-or-path', type=str,
+ default='http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg',
+ help='The URL or path to the image to infer')
+ parser.add_argument('--prompt', type=str, default="这是什么?",
+ help='Prompt to infer')
+ parser.add_argument('--n-predict', type=int, default=32,
+ help='Max tokens to predict')
+ parser.add_argument('--low-bit', type=str, default='sym_int4', help='The quantization type the model will convert to.')
+ parser.add_argument('--gpu-num', type=int, default=2, help='GPU number to use')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+ image_path = args.image_url_or_path
+
+ model = AutoModelForCausalLM.from_pretrained(model_path,
+ load_in_low_bit=args.low_bit,
+ optimize_model=True,
+ trust_remote_code=True,
+ use_cache=True,
+ pipeline_parallel_stages=args.gpu_num)
+ model = model.half()
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ local_rank = torch.distributed.get_rank()
+
+ query = args.prompt
+ if os.path.exists(image_path):
+ image = Image.open(image_path)
+ else:
+ image = Image.open(requests.get(image_path, stream=True).raw)
+
+ # here the prompt tuning refers to https://huggingface.co/THUDM/glm-4v-9b/blob/main/README.md
+ inputs = tokenizer.apply_chat_template([{"role": "user", "image": image, "content": query}],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_tensors="pt",
+ return_dict=True) # chat mode
+ inputs = inputs.to(f'xpu:{local_rank}')
+ all_input = [{'image': image_path}, {'text': query}]
+
+ # Generate predicted tokens
+ with torch.inference_mode():
+ gen_kwargs = {"max_new_tokens": args.n_predict, "do_sample": False,}
+ st = time.time()
+ outputs = model.generate(**inputs, **gen_kwargs)
+ outputs = outputs[:, inputs['input_ids'].shape[1]:]
+ end = time.time()
+ if local_rank == args.gpu_num - 1:
+ print(f'Inference time: {end-st} s')
+ output_str = tokenizer.decode(outputs[0])
+ print('-'*20, 'Input', '-'*20)
+ print(f'Message: {all_input}')
+ print('-'*20, 'Output', '-'*20)
+ print(output_str)
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh
new file mode 100644
index 00000000..98e1fa48
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Inference/run_glm_4v_arc_2_card.sh
@@ -0,0 +1,31 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+source /opt/intel/oneapi/setvars.sh
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=9090
+export FI_PROVIDER=tcp
+export USE_XETLA=OFF
+export OMP_NUM_THREADS=6
+if [[ $KERNEL_VERSION != *"6.5"* ]]; then
+ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+fi
+export TORCH_LLM_ALLREDUCE=0
+
+NUM_GPUS=2 # number of used GPU
+# To run glm-4v-9b
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
+ glm_4v_generate.py --repo-id-or-model-path 'THUDM/glm-4v-9b' --gpu-num $NUM_GPUS --low-bit 'sym_int4'
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
index 2b848e1b..a315124b 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
@@ -55,9 +55,7 @@ def chatglm4v_model_forward(
# generate mode with past_key_values. the image features are already mapped
if past_key_values is None:
# not allow for inputs_embeds, because we want to process image feature
- invalidInputError(input_ids is not None and inputs_embeds is None,
- f"{input_ids} should not be None, {inputs_embeds} should be None.")
- if not is_empty(images): # multi-modality
+ if not is_empty(images) and input_ids is not None: # multi-modality
image_size: int = self.config.vision_config['image_size']
patch_size: int = self.config.vision_config['patch_size']
num_patches = (image_size // patch_size // 2) ** 2
@@ -99,10 +97,13 @@ def chatglm4v_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- batch_size, seq_length = input_ids.shape
-
if inputs_embeds is None:
+ batch_size, seq_length = input_ids.shape
inputs_embeds = self.embedding(input_ids)
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ input_ids = torch.empty((batch_size, seq_length),
+ dtype=inputs_embeds.dtype, device=inputs_embeds.device)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or\
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 8ff87af2..91d23797 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -229,13 +229,14 @@ def generate(
generation_config.pad_token_id = eos_token_id
if generation_config is not None and generation_config.max_new_tokens is not None:
- max_new_tokens = generation_config.max_new_tokens
+ max_new_tokens = generation_config.pop("max_new_tokens")
else:
- max_new_tokens = kwargs.get("max_new_tokens", None)
+ max_new_tokens = kwargs.pop("max_new_tokens", None)
return self.pipeline_parallel_generate(inputs=inputs,
max_new_tokens=max_new_tokens,
- generation_config=generation_config,)
+ generation_config=generation_config,
+ **kwargs)
return original_generate(self,
inputs=inputs,
@@ -257,6 +258,23 @@ def pipeline_parallel_generate(self,
max_new_tokens: int = 32,
generation_config: Optional[GenerationConfig] = None,
**kwargs):
+ model_kwargs = generation_config.update(**kwargs)
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ bs = inputs_tensor.shape[0]
+ if self.config.is_encoder_decoder:
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=bs,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ bos_token_id=generation_config.bos_token_id,
+ device=inputs_tensor.device,
+ )
+ else:
+ input_ids = inputs_tensor if model_input_name == "input_ids" \
+ else model_kwargs.pop("input_ids")
local_rank = dist.get_rank()
pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
next_rank = (local_rank + 1) % self.pipeline_parallel_stages
@@ -272,36 +290,44 @@ def pipeline_parallel_generate(self,
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
- eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device) \
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \
if eos_token_id is not None else None
_input_ids = None
_past_key_values = None
- bs = inputs.shape[0]
- output_ids = inputs.clone()
+
+ bs = input_ids.shape[0]
+ output_ids = input_ids.clone()
_check_quantize_kv_cache(self, layer_start, bs)
step = 0
# keep track of which sequences are already finished
- unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device)
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False
while True:
if step >= max_new_tokens:
break
if _input_ids is None:
- _input_ids = inputs
+ _input_ids = input_ids
tic = time.time()
if local_rank == 0:
outputs = self(input_ids=_input_ids, inputs_embeds=None,
- past_key_values=_past_key_values, use_cache=True)
+ past_key_values=_past_key_values, use_cache=True, **model_kwargs)
else:
- inputs_embeds = torch.empty(_input_ids.shape + (self.config.hidden_size,),
+ _inputs_shape = _input_ids.shape + (self.config.hidden_size,)
+ if step == 0 and self.config.model_type == "chatglm" \
+ and hasattr(self.config, "vision_config"):
+ # for glm-4v, image features are mapped during 1st token
+ # 1597 are computed according to computation process of conv
+ _images_feature = 1597 + _input_ids.shape[0] * 2 + _input_ids.shape[1]
+ _inputs_shape = (_input_ids.shape[0], _images_feature, self.config.hidden_size,)
+ inputs_embeds = torch.empty(_inputs_shape,
device=f'xpu:{local_rank}', dtype=self.dtype)
dist.recv(inputs_embeds, src=pre_rank)
outputs = self(input_ids=None, inputs_embeds=inputs_embeds,
- past_key_values=_past_key_values, use_cache=True)
+ past_key_values=_past_key_values, use_cache=True, **model_kwargs)
if local_rank == self.pipeline_parallel_stages - 1:
logits = outputs.logits
@@ -323,7 +349,8 @@ def pipeline_parallel_generate(self,
"make sure that `pad_token_id` is defined.")
next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
- if self.config.model_type == "chatglm" and self.config.num_layers == 40:
+ if self.config.model_type == "chatglm" and self.config.num_layers == 40 \
+ and not hasattr(self.config, "vision_config"):
# for glm-4-9b-chat
if step == 0:
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
@@ -337,7 +364,7 @@ def pipeline_parallel_generate(self,
_past_key_values = outputs.past_key_values
elif self.config.model_type in ["baichuan", "chatglm"] or \
(self.config.model_type == "qwen" and hasattr(self.config, "visual")):
- # for baichuan2, chatglm3, Qwen-VL-Chat
+ # for baichuan2, chatglm3, Qwen-VL-Chat, glm-4v-9b
if local_rank != 0:
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
past_key_values_placeholder = tuple(