optimize npu llama2 perf again (#11445)
parent 13f59ae6b4
commit f89ca23748

2 changed files with 123 additions and 2 deletions
@@ -31,6 +31,9 @@ def optimize_llm(model: torch.nn.Module):
         model.apply(merge_qkv)
         from ipex_llm.transformers.npu_models.llama import merge_mlp
         model.apply(merge_mlp)
+        from ipex_llm.transformers.npu_models.llama import llama_model_forward
+        from transformers.models.llama.modeling_llama import LlamaModel
+        convert_forward(model, LlamaModel, llama_model_forward)
         from ipex_llm.transformers.npu_models.llama import llama_attention_forward
         from transformers.models.llama.modeling_llama import LlamaAttention
         convert_forward(model, LlamaAttention, llama_attention_forward)
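The new lines reuse the same convert_forward helper that the existing LlamaAttention patch already goes through, now also swapping LlamaModel.forward for the NPU-oriented llama_model_forward added further below. As a rough illustration of the patching pattern only (not the library's actual implementation, whose details may differ), a convert_forward-style helper essentially rebinds forward on every matching submodule:

    # Sketch of a convert_forward-style patcher, for illustration only;
    # the real helper in ipex_llm may differ in signature and details.
    import types
    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
        # Rebind `forward` on every submodule that is an instance of target_cls,
        # so subsequent calls run the patched NPU-oriented implementation.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)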
@@ -32,13 +32,15 @@
 # limitations under the License.
 
 
-from typing import Optional, Tuple
-from transformers.cache_utils import Cache
+from typing import Optional, Tuple, List, Union
 
 import torch
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
 
+from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.npu_models.common import merge_linear
 
 
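The newly imported invalidInputError is ipex_llm's argument-checking helper; llama_model_forward below calls it with a false condition when both or neither of input_ids and inputs_embeds are supplied. Functionally it behaves like the guard sketched here (a hypothetical stand-in; the real helper also logs before raising):

    # Hypothetical stand-in for ipex_llm.utils.common.log4Error.invalidInputError,
    # shown only to make the guard pattern in llama_model_forward explicit.
    def invalid_input_error(condition: bool, err_msg: str) -> None:
        if not condition:
            raise RuntimeError(err_msg)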
@@ -63,6 +65,122 @@ def merge_mlp(module: torch.nn.Module):
         del module.gate_proj, module.up_proj
 
 
+def llama_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        invalidInputError(False,
+                          ("You cannot specify both input_ids and inputs_embeds at the same time, "
+                           "and must specify either one"))
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+
+    # ipex-llm changes start
+    from ipex_llm.transformers.kv import DynamicNormalCache
+    if use_cache and not isinstance(past_key_values, DynamicNormalCache):
+        past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+        past_seen_tokens = past_key_values.get_seq_length()
+
+    if cache_position is None:
+        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
+                                      device=inputs_embeds.device)
+    # ipex-llm changes end
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
+                                           cache_position, past_seen_tokens)
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    # ipex-llm changes start
+    next_cache = next_decoder_cache if use_cache else None
+    # ipex-llm changes end
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache,
+                                 all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
 def llama_attention_forward(
     self,
     hidden_states: torch.Tensor,
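Taken together, the commit routes the whole LlamaModel forward pass (not just attention) through the NPU-oriented implementation, keeping the KV cache in ipex-llm's DynamicNormalCache. A hedged usage sketch of how the patched model would be exercised end to end follows; the checkpoint name is a placeholder and the optimize_llm import path is an assumption, since the changed file names are not shown above.

    # Hypothetical end-to-end sketch: load a stock HF Llama-2 checkpoint, apply
    # optimize_llm() so LlamaModel/LlamaAttention use the patched forwards from
    # this commit, then generate. The optimize_llm import path and checkpoint
    # name are assumptions, not taken from the diff above.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from ipex_llm.transformers.npu_models.convert import optimize_llm  # assumed path

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    optimize_llm(model)  # merges QKV/MLP projections and swaps in the NPU forwards

    inputs = tokenizer("What is AI?", return_tensors="pt")
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))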