optimize phi3-v encoder npu performance and add multimodal example (#11553)

* phi3-v
* readme

parent 70ab1a6f1a
commit 105e124752

4 changed files with 370 additions and 0 deletions
@ -0,0 +1,75 @@
# Run Large Multimodal Model on Intel NPU

In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on Large Multimodal Models on [Intel NPUs](../../../README.md). See the table below for verified models.

## Verified Models

| Model        | Model Link                                                                                            |
|--------------|--------------------------------------------------------------------------------------------------------|
| Phi-3-Vision | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)   |

## 0. Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the latest Intel NPU driver.
Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-windows.html to download and unzip the driver.
Then go to **Device Manager** and find **Neural Processors** -> **Intel(R) AI Boost**.
Right-click it, select **Update Driver**, and then manually select the folder unzipped from the driver package.

## Example: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case for a Phi-3-vision model to predict the next N tokens using the `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
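
At its core, the example loads the model through the NPU `AutoModelForCausalLM` class; the snippet below is condensed from [generate.py](./generate.py) in this directory (see the full script for the processor setup and generation step):

```python
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-vision-128k-instruct",
                                             trust_remote_code=True,
                                             load_in_low_bit="sym_int4",  # or "sym_int8"
                                             _attn_implementation="eager",  # required for phi-3-vision
                                             modules_to_not_convert=["vision_embed_tokens"])
```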
### 1. Install
#### 1.1 Installation on Windows
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.10 libuv
conda activate llm

# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# below command will install intel_npu_acceleration_library
pip install intel-npu-acceleration-library==1.3

pip install transformers==4.40
```
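
Optionally, you can sanity-check the environment afterwards (a hypothetical quick check, assuming the packages above installed without errors; the printed version should match the pinned `transformers==4.40`):

```cmd
python -c "import ipex_llm, intel_npu_acceleration_library, transformers; print(transformers.__version__)"
```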

### 2. Runtime Configurations
For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
#### 2.1 Configurations for Windows

**The following environment variables are required**:

```cmd
set BIGDL_USE_NPU=1
```
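
If you launch the example from a Python script or notebook rather than a `cmd` shell, the same variable can be set in-process (a minimal sketch, not an officially documented alternative; it assumes the variable is read when `ipex_llm` is imported):

```python
import os

# Equivalent to `set BIGDL_USE_NPU=1`; do this before importing ipex_llm.
os.environ["BIGDL_USE_NPU"] = "1"
```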

### 3. Running examples

```cmd
python ./generate.py
```

Arguments info (an example invocation follows the list):
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Phi-3-vision model (e.g. `microsoft/Phi-3-vision-128k-instruct`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'microsoft/Phi-3-vision-128k-instruct'`; for more verified models, please see the list in [Verified Models](#verified-models).
- `--image-url-or-path IMAGE_URL_OR_PATH`: argument defining the image to be inferred. It defaults to `'http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg'`.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is in the image?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. It defaults to `sym_int4`; `sym_int8` can also be used.

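For example, a run that spells out every argument explicitly (values are the defaults listed above, with `sym_int8` substituted for the low-bit format) could look like:

```cmd
python ./generate.py --repo-id-or-model-path microsoft/Phi-3-vision-128k-instruct --image-url-or-path http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg --prompt "What is in the image?" --n-predict 32 --load_in_low_bit sym_int8
```
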
#### Sample Output
#### [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)

```log
Inference time: xxxx s
-------------------- Prompt --------------------
Message: [{'role': 'user', 'content': '<|image_1|>\nWhat is in the image?'}]
Image link/path: http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
-------------------- Output --------------------


What is in the image?
 The image shows a young girl holding a white teddy bear. She is wearing a pink dress with a heart on it. The background includes a stone
```

The sample input image (fetched from the [COCO dataset](https://cocodataset.org/#explore?id=264959)) is:

<a href="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"><img width=400px src="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg" ></a>

@ -0,0 +1,93 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import time
import torch
import argparse
import requests

from PIL import Image
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoProcessor

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-3-vision model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/Phi-3-vision-128k-instruct",
                        help='The huggingface repo id for the phi-3-vision model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--image-url-or-path', type=str,
                        default="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg",
                        help='The URL or path to the image to infer')
    parser.add_argument('--prompt', type=str, default="What is in the image?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')
    parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
                        help='Load in low bit to use')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    image_path = args.image_url_or_path

    # Load model in low bit (SYM_INT4 by default),
    # which converts the relevant layers in the model into SYM_INT4 format.
    # You could also try `'sym_int8'` for INT8.
    # `_attn_implementation="eager"` is required for phi-3-vision.
    # `modules_to_not_convert=["vision_embed_tokens"]` and `model = model.half()` are for acceleration and are optional.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 trust_remote_code=True,
                                                 load_in_low_bit=args.load_in_low_bit,
                                                 _attn_implementation="eager",
                                                 modules_to_not_convert=["vision_embed_tokens"])

    # Load processor
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # here the message formatting refers to https://huggingface.co/microsoft/Phi-3-vision-128k-instruct#sample-inference-code
    messages = [
        {"role": "user", "content": "<|image_1|>\n{prompt}".format(prompt=args.prompt)},
    ]
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    if os.path.exists(image_path):
        image = Image.open(image_path)
    else:
        image = Image.open(requests.get(image_path, stream=True).raw)

    # Generate predicted tokens
    with torch.inference_mode():
        # start inference
        st = time.time()

        inputs = processor(prompt, [image], return_tensors="pt")
        output = model.generate(**inputs,
                                eos_token_id=processor.tokenizer.eos_token_id,
                                num_beams=1,
                                do_sample=False,
                                max_new_tokens=args.n_predict,
                                temperature=0.0)
        end = time.time()
        print(f'Inference time: {end-st} s')
        output_str = processor.decode(output[0],
                                      skip_special_tokens=True,
                                      clean_up_tokenization_spaces=False)
        print('-'*20, 'Prompt', '-'*20)
        print(f'Message: {messages}')
        print(f'Image link/path: {image_path}')
        print('-'*20, 'Output', '-'*20)
        print(output_str)
@ -177,3 +177,15 @@ def optimize_llm(model: torch.nn.Module):
        model.apply(merge_mlp)

        convert_forward(model, module.MLP, baichuan_mlp_forward)

    elif model.config.model_type == "phi3_v":
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.npu_models.phi3_v import merge_qkv
        from ipex_llm.transformers.npu_models.phi3_v import phi3v_encoder_attention_forward
        from ipex_llm.transformers.npu_models.phi3_v import phi3v_model_forward
        model.apply(merge_qkv)

        from transformers.models.clip.modeling_clip import CLIPAttention
        convert_forward(model, CLIPAttention, phi3v_encoder_attention_forward)
        convert_forward(model, module.Phi3VModel, phi3v_model_forward)
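
Roughly speaking, a `convert_forward`-style patch rebinds the `forward` method of every submodule of a given class, which is how the CLIP encoder attention and `Phi3VModel` forwards above get swapped in. A minimal illustrative sketch of that idea (a stand-in, not the actual ipex-llm helper):

```python
import types

import torch


def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Bind `new_forward` (a plain function taking `self` first) as the forward
    # method of every submodule that is an instance of `target_cls`.
    for submodule in model.modules():
        if isinstance(submodule, target_cls):
            submodule.forward = types.MethodType(new_forward, submodule)
```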

python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py (190 lines, new file)
@ -0,0 +1,190 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import importlib
from torch import nn
from typing import Optional, Tuple, List
from transformers.models.clip.modeling_clip import CLIPAttention
from ipex_llm.utils.common.log4Error import invalidInputError


def merge_qkv(module: torch.nn.Module):
    if isinstance(module, CLIPAttention):
        new_weight = torch.cat([
            module.q_proj.weight.data,
            module.k_proj.weight.data,
            module.v_proj.weight.data,
        ], dim=0)

        if module.q_proj.bias is not None:
            qkv_proj = torch.nn.Linear(0, 0, bias=True)
            new_bias = torch.cat([
                module.q_proj.bias.data,
                module.k_proj.bias.data,
                module.v_proj.bias.data,
            ], dim=0)
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        else:
            qkv_proj = torch.nn.Linear(0, 0, bias=False)
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        qkv_proj.in_features = new_weight.size(1)
        qkv_proj.out_features = new_weight.size(0)
        module.qkv_proj = qkv_proj

        del module.q_proj, module.k_proj, module.v_proj


def phi3v_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    # ipex-llm changes start
    from ipex_llm.transformers.kv import DynamicNormalCache
    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if use_cache:
        if not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    modeling_module_name = self.__class__.__module__
    module = importlib.import_module(modeling_module_name)
    return module.Phi3VModel.forward(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        pixel_values=pixel_values,
        image_sizes=image_sizes,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )


def phi3v_encoder_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    causal_attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, tgt_len, embed_dim = hidden_states.size()

    # fused QKV projection (see merge_qkv above), then split back into per-head q/k/v
    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, tgt_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_heads,
                                                        self.num_heads], dim=1)

    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
    query_states = query_states.reshape(*proj_shape)
    key_states = key_states.reshape(*proj_shape)
    value_states = value_states.reshape(*proj_shape)

    src_len = key_states.size(1)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        invalidInputError(
            False,
            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)},"
            f" but is {attn_weights.size()}"
        )

    # apply the causal_attention_mask first
    if causal_attention_mask is not None:
        if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                f" {causal_attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) \
            + causal_attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)},"
                f" but is {attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if output_attentions:
        # this operation is a bit awkward, but it's required to
        # make sure that attn_weights keeps its gradient.
        # In order to do so, attn_weights has to be reshaped
        # twice and has to be reused in the following
        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    attn_output = torch.bmm(attn_probs, value_states)

    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)},"
            f" but is {attn_output.size()}"
        )

    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights_reshaped
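
The fused `qkv_proj` built by `merge_qkv` is meant to be numerically equivalent to the original separate q/k/v projections. A small standalone check of that property (hypothetical code, assuming `transformers` 4.40 and the `phi3_v` module added in this commit are installed) could look like:

```python
import torch
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention

from ipex_llm.transformers.npu_models.phi3_v import merge_qkv

config = CLIPVisionConfig()          # default CLIP vision settings, only for this check
attn = CLIPAttention(config).eval()
x = torch.randn(1, 4, config.hidden_size)

# Reference projections before fusing
q_ref, k_ref, v_ref = attn.q_proj(x), attn.k_proj(x), attn.v_proj(x)

merge_qkv(attn)                      # replaces q/k/v_proj with a single fused qkv_proj
q, k, v = attn.qkv_proj(x).split(config.hidden_size, dim=-1)

assert torch.allclose(q, q_ref) and torch.allclose(k, k_ref) and torch.allclose(v, v_ref)
```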