optimize llama npu perf (#11426)
This commit is contained in:
		
							parent
							
								
									e473b8d946
								
							
						
					
					
						commit
						9f6e5b4fba
					
				
					 4 changed files with 190 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -27,6 +27,7 @@ import intel_npu_acceleration_library as npu_lib
 | 
			
		|||
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.utils import logger
 | 
			
		||||
from ipex_llm.transformers.npu_models.convert import optimize_llm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_flash_attn_import(filename: str) -> List[str]:
 | 
			
		||||
| 
						 | 
				
			
			@ -112,7 +113,23 @@ class _BaseAutoModelClass:
 | 
			
		|||
        model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Converting model, it may takes up to several minutes ...")
 | 
			
		||||
        model = npu_lib.compile(model, qtype, False)
 | 
			
		||||
        try:
 | 
			
		||||
            # for intel_npu_acceleration_library >= 1.1.0
 | 
			
		||||
            from intel_npu_acceleration_library.quantization import quantize_model
 | 
			
		||||
            from intel_npu_acceleration_library.compiler import (
 | 
			
		||||
                apply_horizontal_fusion, create_npu_kernels
 | 
			
		||||
            )
 | 
			
		||||
            with torch.no_grad():
 | 
			
		||||
                optimize_llm(model)
 | 
			
		||||
                apply_horizontal_fusion(model)
 | 
			
		||||
                if not qtype.is_floating_point:
 | 
			
		||||
                    model = quantize_model(model, qtype)
 | 
			
		||||
                create_npu_kernels(model)
 | 
			
		||||
            model = model.eval()
 | 
			
		||||
        except ImportError as _e:
 | 
			
		||||
            # for intel_npu_acceleration_library < 1.1.0
 | 
			
		||||
            model = npu_lib.compile(model, qtype, False)
 | 
			
		||||
        logger.info(f"Finish to convert model")
 | 
			
		||||
 | 
			
		||||
        # add save_low_bit to pretrained model dynamically
 | 
			
		||||
        model.save_low_bit = types.MethodType(cls.save_low_bit, model)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										15
									
								
								python/llm/src/ipex_llm/transformers/npu_models/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								python/llm/src/ipex_llm/transformers/npu_models/__init__.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,15 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										34
									
								
								python/llm/src/ipex_llm/transformers/npu_models/convert.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								python/llm/src/ipex_llm/transformers/npu_models/convert.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,34 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_forward(m, target_m, new_forward):
 | 
			
		||||
    if m.__class__ == target_m:
 | 
			
		||||
        bound_method = new_forward.__get__(m, m.__class__)
 | 
			
		||||
        setattr(m, "forward", bound_method)
 | 
			
		||||
    for _, sub_m in m.named_children():
 | 
			
		||||
        convert_forward(sub_m, target_m, new_forward)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def optimize_llm(model: torch.nn.Module):
 | 
			
		||||
    if model.config.model_type == "llama":
 | 
			
		||||
        from ipex_llm.transformers.npu_models.llama import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
        from ipex_llm.transformers.npu_models.llama import llama_attention_forward
 | 
			
		||||
        from transformers.models.llama.modeling_llama import LlamaAttention
 | 
			
		||||
        convert_forward(model, LlamaAttention, llama_attention_forward)
 | 
			
		||||
							
								
								
									
										123
									
								
								python/llm/src/ipex_llm/transformers/npu_models/llama.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								python/llm/src/ipex_llm/transformers/npu_models/llama.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,123 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
 | 
			
		||||
# which is licensed under Apache License 2.0:
 | 
			
		||||
#
 | 
			
		||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers.models.llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    if isinstance(module, LlamaAttention):
 | 
			
		||||
        new_weight = torch.cat([
 | 
			
		||||
            module.q_proj.weight.data,
 | 
			
		||||
            module.k_proj.weight.data,
 | 
			
		||||
            module.v_proj.weight.data,
 | 
			
		||||
        ], dim=0)
 | 
			
		||||
 | 
			
		||||
        if module.q_proj.bias is not None:
 | 
			
		||||
            qkv_proj = torch.nn.Linear(0, 0, bias=True)
 | 
			
		||||
            new_bias = torch.cat([
 | 
			
		||||
                module.q_proj.bias.data,
 | 
			
		||||
                module.k_proj.bias.data,
 | 
			
		||||
                module.v_proj.bias.data,
 | 
			
		||||
            ], dim=0)
 | 
			
		||||
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
 | 
			
		||||
        else:
 | 
			
		||||
            qkv_proj = torch.nn.Linear(0, 0, bias=False)
 | 
			
		||||
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
 | 
			
		||||
        qkv_proj.in_features = new_weight.size(1)
 | 
			
		||||
        qkv_proj.out_features = new_weight.size(0)
 | 
			
		||||
        module.qkv_proj = qkv_proj
 | 
			
		||||
 | 
			
		||||
        del module.q_proj, module.k_proj, module.v_proj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama_attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
    past_key_value: Optional[Cache] = None,
 | 
			
		||||
    output_attentions: bool = False,
 | 
			
		||||
    use_cache: bool = False,
 | 
			
		||||
    cache_position: Optional[torch.LongTensor] = None,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
 | 
			
		||||
    qkv = self.qkv_proj(hidden_states)
 | 
			
		||||
    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
 | 
			
		||||
    qkv = qkv.transpose(1, 2)
 | 
			
		||||
    query_states, key_states, value_states = qkv.split([self.num_heads,
 | 
			
		||||
                                                        self.num_key_value_heads,
 | 
			
		||||
                                                        self.num_key_value_heads], dim=1)
 | 
			
		||||
 | 
			
		||||
    past_key_value = getattr(self, "past_key_value", past_key_value)
 | 
			
		||||
    cos, sin = self.rotary_emb(value_states, position_ids)
 | 
			
		||||
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 | 
			
		||||
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
 | 
			
		||||
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 | 
			
		||||
        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                         self.layer_idx, cache_kwargs)
 | 
			
		||||
 | 
			
		||||
    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
			
		||||
    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
			
		||||
 | 
			
		||||
    if attention_mask is not None:
 | 
			
		||||
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
 | 
			
		||||
    else:
 | 
			
		||||
        causal_mask = None
 | 
			
		||||
 | 
			
		||||
    attn_output = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
        query_states,
 | 
			
		||||
        key_states,
 | 
			
		||||
        value_states,
 | 
			
		||||
        attn_mask=causal_mask,
 | 
			
		||||
        is_causal=self.is_causal and attention_mask is None and q_len > 1,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
			
		||||
 | 
			
		||||
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
			
		||||
 | 
			
		||||
    attn_output = self.o_proj(attn_output)
 | 
			
		||||
 | 
			
		||||
    if not output_attentions:
 | 
			
		||||
        attn_weights = None
 | 
			
		||||
 | 
			
		||||
    return attn_output, attn_weights, past_key_value
 | 
			
		||||
		Loading…
	
		Reference in a new issue