[NPU] Add support for loading a FunASR model (#12073)
* add support for loading funasr model
* add initial support for paraformer-encoder
* add npu ops impl
* add encoder-decoder npu pipeline
* move the paraformer encoder's first 30 layers to npu and keep the rest on cpu
Parent commit: 854398f6e0
Commit: a0c6432899

7 changed files with 1510 additions and 42 deletions
@ -8,8 +8,9 @@ In this directory, you will find examples on how you could apply IPEX-LLM INT4 o
| Phi-3-Vision | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) |
| MiniCPM-Llama3-V-2_5 | [openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) |
| MiniCPM-V-2_6 | [openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) |
| Speech_Paraformer-Large | [iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch) |

## 0. Requirements
## Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-windows.html to download and unzip the driver.
Then go to **Device Manager**, find **Neural Processors** -> **Intel(R) AI Boost**.
@ -30,6 +31,10 @@ pip install torchvision

# [optional] for MiniCPM-V-2_6
pip install timm torch==2.1.2 torchvision==0.16.2

# [optional] for Speech_Paraformer-Large
pip install -U funasr
pip install modelscope torch==2.1.2 torchaudio==2.1.2
```
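After installing, you can optionally confirm the FunASR-related packages are importable. A minimal sketch (assuming the packages expose `__version__`, as recent releases do; not part of the example scripts):

```python
# Optional sanity check: verify that the dependencies installed above import cleanly.
import funasr
import torchaudio

print(funasr.__version__, torchaudio.__version__)
```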

### 2. Runtime Configurations
@ -64,6 +69,7 @@ Arguments info:
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. Defaults to `32`.
- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. Defaults to `sym_int8`; `sym_int4` can also be used.


#### Sample Output
##### [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)

@ -84,11 +90,12 @@ The sample input image is (which is fetched from [COCO dataset](https://cocodata
<a href="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"><img width=400px src="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg" ></a>

## 4. Run Optimized Models (Experimental)
The examples below show how to run the **_optimized HuggingFace model implementations_** on Intel NPU, including
The examples below show how to run the **_optimized HuggingFace & FunASR model implementations_** on Intel NPU, including
- [MiniCPM-Llama3-V-2_5](./minicpm-llama3-v2.5.py)
- [MiniCPM-V-2_6](./minicpm_v_2_6.py)
- [Speech_Paraformer-Large](./speech_paraformer-large.py)

### Run
### 4.1 Run MiniCPM-Llama3-V-2_5 & MiniCPM-V-2_6
```bash
# to run MiniCPM-Llama3-V-2_5
python minicpm-llama3-v2.5.py
@ -117,4 +124,27 @@ http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
What is in this image?
-------------------- Output --------------------
The image features a young child holding and showing off a white teddy bear wearing a pink dress. The background includes some red flowers and a stone wall, suggesting an outdoor setting.
```
```

### 4.2 Run Speech_Paraformer-Large
```bash
# to run Speech_Paraformer-Large
python speech_paraformer-large.py
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the ASR model's repo id (e.g. `iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch`) to be downloaded, or the path to the ASR checkpoint folder.
- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. Defaults to `sym_int8`; `sym_int4` can also be used.

#### Sample Output
##### [iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch)

```log
# speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav
rtf_avg: 0.090: 100%|███████████████████████████████████| 1/1 [00:01<00:00,  1.18s/it]
[{'key': 'asr_example', 'text': '正 是 因 为 存 在 绝 对 正 义 所 以 我 们 接 受 现 实 的 相 对 正 义 但 是 不 要 因 为 现 实 的 相 对 正 义 我 们 就 认 为 这 个 世 界 没 有 正 义 因 为 如 果 当 你 认 为 这 个 世 界 没 有 正 义'}]

# https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav
rtf_avg: 0.232: 100%|███████████████████████████████████| 1/1 [00:01<00:00,  1.29s/it]
[{'key': 'asr_example_zh', 'text': '欢 迎 大 家 来 体 验 达 摩 院 推 出 的 语 音 识 别 模 型'}]
```
@ -0,0 +1,57 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import torch
import time
import argparse

from ipex_llm.transformers.npu_model import FunAsrAutoModel as AutoModel
from transformers.utils import logging

logger = logging.get_logger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transcribe speech to text using `generate()` API for npu model"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    )
    parser.add_argument('--load_in_low_bit', type=str, default="sym_int8",
                        help='Load in low bit to use')
    parser.add_argument("--intra-pp", type=int, default=2)
    parser.add_argument("--inter-pp", type=int, default=2)

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    model = AutoModel(
        model=model_path,
        attn_implementation="eager",
        load_in_low_bit=args.load_in_low_bit,
        low_cpu_mem_usage=True,
        optimize_model=True,
        intra_pp=args.intra_pp,
        inter_pp=args.inter_pp,
    )

    res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
                         batch_size_s=300,
                         hotword='魔搭')
    print(res)
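As a follow-up to the script above, a hedged sketch of transcribing your own recording with the already-loaded model; the wav path is a placeholder, and `batch_size_s` mirrors the call in the script:

```python
# Hypothetical follow-up (the wav path is a placeholder): reuse the NPU-optimized
# model loaded above to transcribe a local recording.
res = model.generate(input="path/to/your_audio.wav", batch_size_s=300)
print(res[0]["text"])  # generate() returns a list of dicts with 'key' and 'text'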
python/llm/example/NPU/HF-Transformers-AutoModels/README.md (new file, 26 lines)

@ -0,0 +1,26 @@
# IPEX-LLM Examples on Intel NPU

This folder contains examples of running IPEX-LLM on Intel NPU:

- [LLM](LLM): examples of running large language models using IPEX-LLM optimizations
- [Multimodal](Multimodal): examples of running large multimodal models using IPEX-LLM optimizations

## Verified Models on Intel NPU
| Model      | Model Link                                                    |
|------------|----------------------------------------------------------------|
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| Chatglm3 | [THUDM/chatglm3-6b](https://huggingface.co/THUDM/chatglm3-6b) |
| Chatglm2 | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
| Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |
| MiniCPM | [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
| Phi-3 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
| Stablelm | [stabilityai/stablelm-zephyr-3b](https://huggingface.co/stabilityai/stablelm-zephyr-3b) |
| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |
| Deepseek | [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) |
| Mistral | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) |
| Phi-3-Vision | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) |
| MiniCPM-Llama3-V-2_5 | [openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) |
| MiniCPM-V-2_6 | [openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) |
| Speech_Paraformer-Large | [iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch) |
@ -141,17 +141,24 @@ class _BaseAutoModelClass:

        _args = copy.deepcopy(args)
        _kwargs = copy.deepcopy(kwargs)

        try:
            # To handle the input CUDA setting (such as 'device_map={"":0}'), ignore it
            kwargs.pop("device_map", None)
            model = cls.HF_Model.from_pretrained(*args, **kwargs)
            if hasattr(cls.HF_Model, "from_pretrained"):
                model = cls.HF_Model.from_pretrained(*args, **kwargs)
            else:
                model = cls.HF_Model(*args, **kwargs)
        except NotImplementedError:
            logger.info(
                "Failed to load models with `low_cpu_mem_usage` specified, "
                "will fall to traditional load method with higher memory consumption."
            )
            _kwargs["low_cpu_mem_usage"] = False
            model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
            if hasattr(cls.HF_Model, "from_pretrained"):
                model = cls.HF_Model.from_pretrained(*args, **kwargs)
            else:
                model = cls.HF_Model(*args, **kwargs)
            model.config.update({"bigdl_lcmu_enabled": False})

        logger.info(f"Converting model, it may takes up to several minutes ...")
@ -165,42 +172,20 @@ class _BaseAutoModelClass:
                    " than max_output_len ({max_output_len})"
                ),
            )
            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_pre

            if hasattr(model, "llm"):
                llm = model.llm
            else:
                llm = model

            with torch.no_grad():
                model.config.update({"mixed_precision": mixed_precision})
                model.config.update({"group_size": quantization_group_size})
                optimize_llm_pre(model, qtype, mixed_precision,
                                 quantization_group_size=quantization_group_size)
                cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
                                 quantization_group_size, *args, **kwargs)
                create_npu_kernels(llm)
            model = model.eval()
            logger.info(f"Finish to convert model")
            model.config.update({"bigdl_transformers_low_bit": qtype})
            model.share_memory()

            if not pipeline:
                optimize_llm(
                    llm,
                    max_output_len=max_output_len,
                    max_prompt_len=max_prompt_len,
                    inter_pp=inter_pp,
                    intra_pp=intra_pp,
                    transpose_value_cache=transpose_value_cache,
                    group_size=quantization_group_size
                )
                model.save_low_bit = types.MethodType(save_low_bit, model)
            else:
                from ipex_llm.transformers.npu_pipeline_model.convert_pipeline import convert_llm
                convert_llm(llm,
                            kv_len=max_output_len,
                            transpose_value_cache=transpose_value_cache)
            optimize_kwargs = {
                "model": model,
                "qtype": qtype,
                "mixed_precision": mixed_precision,
                "quantization_group_size": quantization_group_size,
                "modules_to_not_convert": modules_to_not_convert,
                "pipeline": pipeline,
                "max_output_len": max_output_len,
                "max_prompt_len": max_prompt_len,
                "inter_pp": inter_pp,
                "intra_pp": intra_pp,
                "transpose_value_cache": transpose_value_cache,
            }
            model = cls.optimize_npu_model(*args, **optimize_kwargs)
        else:
            from ipex_llm.transformers.npu_models.convert import optimize_llm
            optimize_llm(model)
@ -219,6 +204,62 @@ class _BaseAutoModelClass:

        return model

    @classmethod
    def optimize_npu_model(cls, *args, **kwargs):

        from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre, optimize_llm
        from intel_npu_acceleration_library.compiler import create_npu_kernels

        model = kwargs.pop("model")
        qtype = kwargs.pop("qtype", "sym_int4")
        mixed_precision = kwargs.pop("mixed_precision", False)
        quantization_group_size = kwargs.pop("quantization_group_size", 0)
        modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
        pipeline = kwargs.pop("pipeline", False)
        max_output_len = kwargs.pop("max_output_len", 1024)
        max_prompt_len = kwargs.pop("max_prompt_len", 512)
        inter_pp = kwargs.pop("inter_pp", None)
        intra_pp = kwargs.pop("intra_pp", None)
        transpose_value_cache = kwargs.pop("transpose_value_cache", True)

        if hasattr(model, "llm"):
            llm = model.llm
        else:
            llm = model

        with torch.no_grad():
            model.config.update({"mixed_precision": mixed_precision})
            model.config.update({"group_size": quantization_group_size})
            optimize_llm_pre(model, qtype, mixed_precision,
                             quantization_group_size=quantization_group_size)
            cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
                             quantization_group_size, *args, **kwargs)
            create_npu_kernels(llm)
        model = model.eval()
        logger.info(f"Finish to convert model")
        model.config.update({"bigdl_transformers_low_bit": qtype})
        model.share_memory()

        if not pipeline:
            optimize_llm(
                llm,
                max_output_len=max_output_len,
                max_prompt_len=max_prompt_len,
                inter_pp=inter_pp,
                intra_pp=intra_pp,
                transpose_value_cache=transpose_value_cache,
                group_size=quantization_group_size
            )
            model.save_low_bit = types.MethodType(save_low_bit, model)
        else:
            from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
                import convert_llm
            convert_llm(llm,
                        kv_len=max_output_len,
                        transpose_value_cache=transpose_value_cache)

        return model

    @classmethod
    def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert,
                     group_size=0, *arg, **kwarg):
@ -530,3 +571,52 @@ class AutoModelForMultipleChoice(_BaseAutoModelClass):

class AutoModelForTokenClassification(_BaseAutoModelClass):
    HF_Model = transformers.AutoModelForTokenClassification


class FunAsrAutoModel(_BaseAutoModelClass):
    import funasr
    HF_Model = funasr.AutoModel

    def __init__(self, *args, **kwargs):
        self.model = self.from_pretrained(*args, **kwargs)

    def __getattr__(self, name):
        return getattr(self.model, name)

    @classmethod
    def optimize_npu_model(cls, *args, **kwargs):
        from ipex_llm.transformers.npu_models.convert_mp import optimize_funasr
        from intel_npu_acceleration_library.compiler import create_npu_kernels

        model = kwargs.pop("model")
        qtype = kwargs.pop("qtype", "sym_int8")
        modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
        max_output_len = kwargs.pop("max_output_len", 1024)
        max_prompt_len = kwargs.pop("max_prompt_len", 512)
        inter_pp = kwargs.pop("inter_pp", None)
        intra_pp = kwargs.pop("intra_pp", None)
        transpose_value_cache = kwargs.pop("transpose_value_cache", True)

        encoders = model.model.encoder.encoders[0:31]
        decoders = model.model.decoder.decoders
        with torch.no_grad():
            cls.load_convert(qtype, encoders,
                             "cpu", modules_to_not_convert, *args, **kwargs)
            create_npu_kernels(encoders)
            cls.load_convert(qtype, decoders,
                             "cpu", modules_to_not_convert, *args, **kwargs)
            create_npu_kernels(decoders)
        logger.info(f"Finish to convert model")
        model.model.share_memory()

        optimize_funasr(
            model,
            max_output_len=max_output_len,
            max_prompt_len=max_prompt_len,
            inter_pp=inter_pp,
            intra_pp=intra_pp,
            transpose_value_cache=transpose_value_cache,
        )
        model.save_low_bit = types.MethodType(save_low_bit, model)

        return model
@ -301,3 +301,43 @@ def optimize_llm(

    if isinstance(model.lm_head, SlicedLMHead):
        model.lm_head.get_fused_lm_head()


def optimize_funasr(
    model: torch.nn.Module,
    max_output_len=1024,
    max_prompt_len=1024,
    inter_pp=None,
    intra_pp=None,
    transpose_value_cache=True,
):
    if intra_pp is None:
        intra_pp = 2
    if inter_pp is None:
        inter_pp = 2
    from ipex_llm.transformers.npu_models.paraformer_mp import gen_funasr_fused_encoder_forward, \
        gen_funasr_fused_decoder_forward
    from ipex_llm.transformers.npu_models.paraformer_mp import PrefillRunner, DecodeRunner
    prefill_runner = PrefillRunner(
        model,
        max_output_len=max_output_len,
        max_prompt_len=max_prompt_len,
        transpose_value_cache=transpose_value_cache,
    )
    encoder_forward = gen_funasr_fused_encoder_forward(
        prefill_runner=prefill_runner
    )
    decode_runner = DecodeRunner(
        model,
        max_seq_len=max_output_len,
        inter_pp=inter_pp,
        intra_pp=intra_pp,
        transpose_value_cache=transpose_value_cache,
    )
    decoder_forward = gen_funasr_fused_decoder_forward(
        decode_runner=decode_runner
    )
    from funasr.models.sanm.encoder import SANMEncoder
    from funasr.models.paraformer.decoder import ParaformerSANMDecoder
    convert_forward(model.model, SANMEncoder, encoder_forward)
    convert_forward(model.model, ParaformerSANMDecoder, decoder_forward)
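`convert_forward` is used above to rebind the `forward` of every matching submodule so the fused NPU implementations take over. A rough, illustrative sketch of what such a helper does (not the exact ipex_llm code):

```python
import torch.nn as nn

def convert_forward_sketch(module: nn.Module, target_cls, new_forward):
    # Recursively rebind `forward` on every submodule of type `target_cls`
    # so the generated fused forward is called instead of the original one.
    for _, child in module.named_children():
        if isinstance(child, target_cls):
            child.forward = new_forward.__get__(child, child.__class__)
        convert_forward_sketch(child, target_cls, new_forward)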
@ -80,6 +80,7 @@ def run_model(
    models = _model_cache.get(key, None)

    input_shapes = [elem.shape for elem in x_np]

    if models is None:
        _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)])
    elif len(models) < 1:
@ -315,6 +316,165 @@ class LLMBaseNNFactory(NNFactory):

        return attn_output, new_key_states, new_value_states

    def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias):
        hidden_states = self.convert_to_fp32(hidden_states)
        mean_res = self.reduce_mean(hidden_states, -1, keep_dims=True,)
        variance = self.reduce_mean(
            self.power(hidden_states - mean_res, self.constant(np.array([[2]], dtype=np.float32))),
            -1,
            keep_dims=True,
        )
        eps = self.constant(1e-12)
        hidden_states = self.eltwise_div(hidden_states - mean_res,
                                         self.sqrt(self.eltwise_add(variance, eps)))
        layernorm_weight = self.convert_to_fp32(layernorm_weight)
        hidden_states = self.eltwise_mul(layernorm_weight, hidden_states)
        layernorm_bias = self.convert_to_fp32(layernorm_bias)
        hidden_states = self.eltwise_add(layernorm_bias, hidden_states)
        hidden_states = self.convert_to_fp16(hidden_states)
        return hidden_states
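For reference, a hedged plain-PyTorch sketch of the LayerNorm semantics that the NPU-op composition in `paraformer_layer_norm` reproduces (fp32 compute, `eps=1e-12`, fp16 output):

```python
import torch

def paraformer_layer_norm_ref(x, weight, bias, eps=1e-12):
    # Plain-PyTorch equivalent of the NPU graph above: normalize over the last
    # dimension in fp32, apply scale and shift, then cast back to fp16.
    x32 = x.float()
    mean = x32.mean(-1, keepdim=True)
    var = ((x32 - mean) ** 2).mean(-1, keepdim=True)
    out = weight.float() * (x32 - mean) / torch.sqrt(var + eps) + bias.float()
    return out.half()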

    def self_attn_sanm(self,
                       x,
                       mask,
                       in_feat,
                       n_feat,
                       n_head,
                       fsmn_weight,
                       qkv_bias,
                       out_bias
                       ):
        d_k = n_feat // n_head
        h = n_head
        b, t, d = x.size()

        q_k_v = self.linear(x,
                            1536,
                            512,
                            bias=False,
                            wt_dtype=self.dtype)
        q_k_v = self.eltwise_add(q_k_v, qkv_bias)

        q = self.slice(q_k_v, [0, 0, 0], [1, 218, 512])
        k = self.slice(q_k_v, [0, 0, 512], [1, 218, 2 * 512])
        v = self.slice(q_k_v, [0, 0, 2 * 512], [1, 218, 3 * 512])

        q_h = self.reshape(q, [b, t, h, d_k])
        k_h = self.reshape(k, [b, t, h, d_k])
        v_h = self.reshape(v, [b, t, h, d_k])
        q_h = self.transpose(q_h, [0, 2, 1, 3])
        k_h = self.transpose(k_h, [0, 2, 1, 3])
        v_h = self.transpose(v_h, [0, 2, 1, 3])

        b_v, t_v, d_v = v.size()
        if mask is not None:
            mask = self.reshape(mask, [b_v, -1, 1])
            v = self.eltwise_mul(v, mask)
        v_x = self.transpose(v, [0, 2, 1])

        fsmn_out = self.convolution(input_node=v_x,
                                    weights_node=fsmn_weight,
                                    bias=None,
                                    strides=1,
                                    padding=5,
                                    groups=512,
                                    n_spatial_dims=1)

        fsmn_out = self.transpose(fsmn_out, [0, 2, 1])
        fsmn_out = self.eltwise_add(v, fsmn_out)
        if mask is not None:
            fsmn_out = self.eltwise_mul(fsmn_out, mask)

        q_h = q_h * d_k ** (-0.5)
        scores = self.matmul(q_h, k_h, False, True)
        n_batch = v_h.size(0)
        p_attn = self.softmax(scores, -1)

        x_attn = self.matmul(p_attn, v_h, False, False)
        x_attn = self.transpose(x_attn, [0, 2, 1, 3])
        x_attn = self.reshape(x_attn, [n_batch, -1, h * d_k])

        attn_out = self.linear(x_attn,
                               512,
                               512,
                               bias=False,
                               wt_dtype=self.dtype)
        attn_out = attn_out + out_bias
        attn_res = self.eltwise_add(attn_out, fsmn_out)
        return attn_res
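The FSMN memory block that `self_attn_sanm` assembles from NPU ops (mask, depthwise 1-D convolution over the value stream, residual add, mask again) corresponds roughly to the following plain-PyTorch sketch; the kernel size comes from the checkpoint's FSMN weights and is an assumption here:

```python
import torch
import torch.nn.functional as F

def fsmn_memory_block_ref(v, fsmn_weight, mask=None):
    # v: (batch, time, 512); fsmn_weight: depthwise filters of shape (512, 1, K).
    # padding=5 matches the NPU graph above and preserves sequence length when K == 11.
    if mask is not None:
        v = v * mask
    x = F.conv1d(v.transpose(1, 2), fsmn_weight, stride=1, padding=5, groups=512)
    out = v + x.transpose(1, 2)          # residual add with the (masked) values
    return out * mask if mask is not None else out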

    def sanm_feed_forward(self, x, hidden_units, idim, w1_bias, w2_bias):
        mm = self.linear(x, 2048, 512, bias=False, wt_dtype=self.dtype)
        mm = mm + w1_bias
        mm_act = self.relu(mm)
        output = self.linear(mm_act, 512, 2048, bias=False, wt_dtype=self.dtype)
        output = output + w2_bias
        return output

    def multihead_attn_sanm_decoder(self, inputs, mask, fsmn_weight):
        b, t, d = inputs.size()
        if mask is not None:
            mask = self.reshape(mask, [b, -1, 1])
            inputs = self.eltwise_mul(inputs, mask)
        x = self.transpose(inputs, [0, 2, 1])
        b, d, t = x.size()
        cache = x
        fsmn_x = self.convolution(input_node=x,
                                  weights_node=fsmn_weight,
                                  bias=None,
                                  strides=1,
                                  padding=5,
                                  groups=512,
                                  n_spatial_dims=1)
        fsmn_x = self.transpose(fsmn_x, [0, 2, 1])
        x = self.eltwise_add(inputs, fsmn_x)
        if mask is not None:
            x = self.eltwise_mul(x, mask)
        return x, cache

    def sanm_cross_attn(self, x, memory, mask, q_bias, kv_bias, out_bias, n_feat, n_head):
        b = x.size(0)
        d_k = n_feat // n_head
        h = n_head

        q = self.linear(x, 512, 512, bias=False, wt_dtype=self.dtype)
        q = q + q_bias
        q_h = self.reshape(q, (b, -1, h, d_k))
        q_h = self.transpose(q_h, [0, 2, 1, 3])

        k_v = self.linear(memory, 1024, 512, bias=False, wt_dtype=self.dtype)
        k_v = k_v + kv_bias

        res = k_v.chunk(2, -1)
        k = res[0]
        v = res[1]
        k_h = self.reshape(k, [b, -1, h, d_k])
        v_h = self.reshape(v, [b, -1, h, d_k])
        k_h = self.transpose(k_h, [0, 2, 1, 3])
        v_h = self.transpose(v_h, [0, 2, 1, 3])

        q_h = q_h * d_k ** (-0.5)
        scores = self.matmul(q_h, k_h, False, True)
        n_batch = v_h.size(0)
        # Assume mask is None
        p_attn = self.softmax(scores, -1)

        v_h = self.transpose(v_h, [0, 1, 3, 2])
        x_attn = self.matmul(p_attn, v_h, False, True)
        x_attn = self.transpose(x_attn, [0, 2, 1, 3])
        x_attn = self.reshape(x_attn, [n_batch, -1, h * d_k])
        attn_out = self.linear(x_attn, 512, 512, bias=False, wt_dtype=self.dtype)
        attn_out = attn_out + out_bias
        return attn_out

    def feed_forward_sanm_decoder(self, x, w_1_bias, norm_weights, norm_bias):
        w_1 = self.linear(x, 2048, 512, bias=False, wt_dtype=self.dtype)
        w_1 = w_1 + w_1_bias
        w_1_act = self.relu(w_1)
        w_1_norm = self.paraformer_layer_norm(w_1_act, norm_weights, norm_bias)
        w_2 = self.linear(w_1_norm, 512, 2048, bias=False, wt_dtype=self.dtype)
        return w_2

    def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
        if self.n_splits_linear == 1:
            mm1 = self.linear(
									
								
								python/llm/src/ipex_llm/transformers/npu_models/paraformer_mp.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1065
									
								
								python/llm/src/ipex_llm/transformers/npu_models/paraformer_mp.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							
		Loading…
	
		Reference in a new issue