optimize phi3-v encoder npu performance and add multimodal example (#11553)
* phi3-v * readme
This commit is contained in:
parent
70ab1a6f1a
commit
105e124752
4 changed files with 370 additions and 0 deletions
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Run Large Multimodal Model on Intel NPU
|
||||||
|
In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on Large Multimodal Models on [Intel NPUs](../../../README.md). In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on Large Multimodal Models on Intel NPUs. See the table blow for verified models.
|
||||||
|
|
||||||
|
## Verified Models
|
||||||
|
|
||||||
|
| Model | Model Link |
|
||||||
|
|------------|----------------------------------------------------------------|
|
||||||
|
| Phi-3-Vision | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) |
|
||||||
|
|
||||||
|
## 0. Requirements
|
||||||
|
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
|
||||||
|
Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-windows.html to download and unzip the driver.
|
||||||
|
Then go to **Device Manager**, find **Neural Processors** -> **Intel(R) AI Boost**.
|
||||||
|
Right click and select **Update Driver**. And then manually select the folder unzipped from the driver.
|
||||||
|
|
||||||
|
## Example: Predict Tokens using `generate()` API
|
||||||
|
In the example [generate.py](./generate.py), we show a basic use case for a phi-3-vision model to predict the next N tokens using `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
|
||||||
|
### 1. Install
|
||||||
|
#### 1.1 Installation on Windows
|
||||||
|
We suggest using conda to manage environment:
|
||||||
|
```bash
|
||||||
|
conda create -n llm python=3.10 libuv
|
||||||
|
conda activate llm
|
||||||
|
|
||||||
|
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||||
|
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||||
|
|
||||||
|
# below command will install intel_npu_acceleration_library
|
||||||
|
pip install intel-npu-acceleration-library==1.3
|
||||||
|
|
||||||
|
pip install transformers==4.40
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Runtime Configurations
|
||||||
|
For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
|
||||||
|
#### 2.1 Configurations for Windows
|
||||||
|
|
||||||
|
**Following envrionment variables are required**:
|
||||||
|
|
||||||
|
```cmd
|
||||||
|
set BIGDL_USE_NPU=1
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Running examples
|
||||||
|
|
||||||
|
```
|
||||||
|
python ./generate.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments info:
|
||||||
|
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Phi-3-vision model (e.g. `microsoft/Phi-3-vision-128k-instruct`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `'microsoft/Phi-3-vision-128k-instruct'`, and more verified models please see the list in [Verified Models](#verified-models).
|
||||||
|
- `--image-url-or-path IMAGE_URL_OR_PATH`: argument defining the image to be infered. It is default to be `'http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg'`.
|
||||||
|
- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). It is default to be `'What is in the image?'`.
|
||||||
|
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `32`.
|
||||||
|
- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. It is default to be `sym_int8`, `sym_int4` can also be used.
|
||||||
|
|
||||||
|
#### Sample Output
|
||||||
|
#### [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)
|
||||||
|
|
||||||
|
```log
|
||||||
|
Inference time: xxxx s
|
||||||
|
-------------------- Prompt --------------------
|
||||||
|
Message: [{'role': 'user', 'content': '<|image_1|>\nWhat is in the image?'}]
|
||||||
|
Image link/path: http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
|
||||||
|
-------------------- Output --------------------
|
||||||
|
|
||||||
|
|
||||||
|
What is in the image?
|
||||||
|
The image shows a young girl holding a white teddy bear. She is wearing a pink dress with a heart on it. The background includes a stone
|
||||||
|
```
|
||||||
|
|
||||||
|
The sample input image is (which is fetched from [COCO dataset](https://cocodataset.org/#explore?id=264959)):
|
||||||
|
|
||||||
|
<a href="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"><img width=400px src="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg" ></a>
|
||||||
|
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-3 model')
|
||||||
|
parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/Phi-3-vision-128k-instruct",
|
||||||
|
help='The huggingface repo id for the phi-3-vision model to be downloaded'
|
||||||
|
', or the path to the huggingface checkpoint folder')
|
||||||
|
parser.add_argument('--image-url-or-path', type=str,
|
||||||
|
default="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg",
|
||||||
|
help='The URL or path to the image to infer')
|
||||||
|
parser.add_argument('--prompt', type=str, default="What is in the image?",
|
||||||
|
help='Prompt to infer')
|
||||||
|
parser.add_argument('--n-predict', type=int, default=32,
|
||||||
|
help='Max tokens to predict')
|
||||||
|
parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
|
||||||
|
help='Load in low bit to use')
|
||||||
|
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
model_path = args.repo_id_or_model_path
|
||||||
|
image_path = args.image_url_or_path
|
||||||
|
|
||||||
|
# Load model in SYM_INT4,
|
||||||
|
# which convert the relevant layers in the model into SYM_INT4 format
|
||||||
|
# You could also try `'sym_int8'` for INT8
|
||||||
|
# `_attn_implementation="eager"` is required for phi-3-vision
|
||||||
|
# `modules_to_not_convert=["vision_embed_tokens"]` and `model = model.half()` are for acceleration and are optional
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
load_in_low_bit=args.load_in_low_bit,
|
||||||
|
_attn_implementation="eager",
|
||||||
|
modules_to_not_convert=["vision_embed_tokens"])
|
||||||
|
|
||||||
|
# Load processor
|
||||||
|
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
# here the message formatting refers to https://huggingface.co/microsoft/Phi-3-vision-128k-instruct#sample-inference-code
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "<|image_1|>\n{prompt}".format(prompt=args.prompt)},
|
||||||
|
]
|
||||||
|
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
|
||||||
|
if os.path.exists(image_path):
|
||||||
|
image = Image.open(image_path)
|
||||||
|
else:
|
||||||
|
image = Image.open(requests.get(image_path, stream=True).raw)
|
||||||
|
|
||||||
|
# Generate predicted tokens
|
||||||
|
with torch.inference_mode():
|
||||||
|
# start inference
|
||||||
|
st = time.time()
|
||||||
|
|
||||||
|
inputs = processor(prompt, [image], return_tensors="pt")
|
||||||
|
output = model.generate(**inputs,
|
||||||
|
eos_token_id=processor.tokenizer.eos_token_id,
|
||||||
|
num_beams=1,
|
||||||
|
do_sample=False,
|
||||||
|
max_new_tokens=args.n_predict,
|
||||||
|
temperature=0.0)
|
||||||
|
end = time.time()
|
||||||
|
print(f'Inference time: {end-st} s')
|
||||||
|
output_str = processor.decode(output[0],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
clean_up_tokenization_spaces=False)
|
||||||
|
print('-'*20, 'Prompt', '-'*20)
|
||||||
|
print(f'Message: {messages}')
|
||||||
|
print(f'Image link/path: {image_path}')
|
||||||
|
print('-'*20, 'Output', '-'*20)
|
||||||
|
print(output_str)
|
||||||
|
|
@ -177,3 +177,15 @@ def optimize_llm(model: torch.nn.Module):
|
||||||
model.apply(merge_mlp)
|
model.apply(merge_mlp)
|
||||||
|
|
||||||
convert_forward(model, module.MLP, baichuan_mlp_forward)
|
convert_forward(model, module.MLP, baichuan_mlp_forward)
|
||||||
|
|
||||||
|
elif model.config.model_type == "phi3_v":
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from ipex_llm.transformers.npu_models.phi3_v import merge_qkv
|
||||||
|
from ipex_llm.transformers.npu_models.phi3_v import phi3v_encoder_attention_forward
|
||||||
|
from ipex_llm.transformers.npu_models.phi3_v import phi3v_model_forward
|
||||||
|
model.apply(merge_qkv)
|
||||||
|
|
||||||
|
from transformers.models.clip.modeling_clip import CLIPAttention
|
||||||
|
convert_forward(model, CLIPAttention, phi3v_encoder_attention_forward)
|
||||||
|
convert_forward(model, module.Phi3VModel, phi3v_model_forward)
|
||||||
|
|
|
||||||
190
python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py
Normal file
190
python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Some parts of this file is adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# which is licensed under Apache License 2.0:
|
||||||
|
#
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import importlib
|
||||||
|
from torch import nn
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
from transformers.models.clip.modeling_clip import CLIPAttention
|
||||||
|
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||||
|
|
||||||
|
|
||||||
|
def merge_qkv(module: torch.nn.Module):
|
||||||
|
if isinstance(module, CLIPAttention):
|
||||||
|
new_weight = torch.cat([
|
||||||
|
module.q_proj.weight.data,
|
||||||
|
module.k_proj.weight.data,
|
||||||
|
module.v_proj.weight.data,
|
||||||
|
], dim=0)
|
||||||
|
|
||||||
|
if module.q_proj.bias is not None:
|
||||||
|
qkv_proj = torch.nn.Linear(0, 0, bias=True)
|
||||||
|
new_bias = torch.cat([
|
||||||
|
module.q_proj.bias.data,
|
||||||
|
module.k_proj.bias.data,
|
||||||
|
module.v_proj.bias.data,
|
||||||
|
], dim=0)
|
||||||
|
qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||||
|
else:
|
||||||
|
qkv_proj = torch.nn.Linear(0, 0, bias=False)
|
||||||
|
qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||||
|
qkv_proj.in_features = new_weight.size(1)
|
||||||
|
qkv_proj.out_features = new_weight.size(0)
|
||||||
|
module.qkv_proj = qkv_proj
|
||||||
|
|
||||||
|
del module.q_proj, module.k_proj, module.v_proj
|
||||||
|
|
||||||
|
|
||||||
|
def phi3v_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
image_sizes: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
# ipex-llm changes start
|
||||||
|
from ipex_llm.transformers.kv import DynamicNormalCache
|
||||||
|
# IPEX-LLM OPT: kv cache and quantize kv cache
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
if use_cache:
|
||||||
|
if not isinstance(past_key_values, DynamicNormalCache):
|
||||||
|
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||||
|
modeling_module_name = self.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
return module.Phi3VModel.forward(
|
||||||
|
self=self,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def phi3v_encoder_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||||
|
|
||||||
|
qkv = self.qkv_proj(hidden_states)
|
||||||
|
qkv = qkv.view(bsz, tgt_len, self.num_heads * 3, self.head_dim)
|
||||||
|
qkv = qkv.transpose(1, 2)
|
||||||
|
query_states, key_states, value_states = qkv.split([self.num_heads,
|
||||||
|
self.num_heads,
|
||||||
|
self.num_heads], dim=1)
|
||||||
|
|
||||||
|
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
query_states = query_states.reshape(*proj_shape)
|
||||||
|
key_states = key_states.reshape(*proj_shape)
|
||||||
|
value_states = value_states.reshape(*proj_shape)
|
||||||
|
|
||||||
|
src_len = key_states.size(1)
|
||||||
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)},"
|
||||||
|
f" but is {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply the causal_attention_mask first
|
||||||
|
if causal_attention_mask is not None:
|
||||||
|
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||||
|
f" {causal_attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) \
|
||||||
|
+ causal_attention_mask
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)},"
|
||||||
|
f" but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
# this operation is a bit akward, but it's required to
|
||||||
|
# make sure that attn_weights keeps its gradient.
|
||||||
|
# In order to do so, attn_weights have to reshaped
|
||||||
|
# twice and have to be reused in the following
|
||||||
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
else:
|
||||||
|
attn_weights_reshaped = None
|
||||||
|
|
||||||
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)},"
|
||||||
|
f" but is {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights_reshaped
|
||||||
Loading…
Reference in a new issue