optimize npu llama2 perf again (#11445)

Yishuo Wang 2024-06-27 15:13:42 +08:00 committed by GitHub
parent 13f59ae6b4
commit f89ca23748
2 changed files with 123 additions and 2 deletions


@@ -31,6 +31,9 @@ def optimize_llm(model: torch.nn.Module):
    model.apply(merge_qkv)
    from ipex_llm.transformers.npu_models.llama import merge_mlp
    model.apply(merge_mlp)
    from ipex_llm.transformers.npu_models.llama import llama_model_forward
    from transformers.models.llama.modeling_llama import LlamaModel
    convert_forward(model, LlamaModel, llama_model_forward)
    from ipex_llm.transformers.npu_models.llama import llama_attention_forward
    from transformers.models.llama.modeling_llama import LlamaAttention
    convert_forward(model, LlamaAttention, llama_attention_forward)
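The convert_forward helper used above is not part of this diff; as a rough sketch (an assumption about its behavior, not the exact ipex_llm implementation), such a helper walks the module tree and rebinds forward on every submodule of the target class:

import torch

def convert_forward(m: torch.nn.Module, target_m: type, new_forward):
    # Rebind `forward` on every submodule whose class matches `target_m`,
    # so calls route through the replacement function instead.
    if m.__class__ == target_m:
        bound_method = new_forward.__get__(m, m.__class__)
        setattr(m, "forward", bound_method)
    for _, sub_m in m.named_children():
        convert_forward(sub_m, target_m, new_forward)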


@@ -32,13 +32,15 @@
# limitations under the License.
from typing import Optional, Tuple
from transformers.cache_utils import Cache
from typing import Optional, Tuple, List, Union
import torch
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.npu_models.common import merge_linear
@@ -63,6 +65,122 @@ def merge_mlp(module: torch.nn.Module):
        del module.gate_proj, module.up_proj


def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        invalidInputError(False,
                          ("You cannot specify both input_ids and inputs_embeds at the same time, "
                           "and must specify either one"))

    if self.gradient_checkpointing and self.training and use_cache:
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0

    # ipex-llm changes start
    from ipex_llm.transformers.kv import DynamicNormalCache
    if use_cache and not isinstance(past_key_values, DynamicNormalCache):
        past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
        past_seen_tokens = past_key_values.get_seq_length()
    if cache_position is None:
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
                                      device=inputs_embeds.device)
    # ipex-llm changes end

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
                                           cache_position, past_seen_tokens)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # ipex-llm changes start
    next_cache = next_decoder_cache if use_cache else None
    # ipex-llm changes end

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache,
                                 all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
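As a usage illustration (not part of this commit): once optimize_llm has been applied, generation on a Llama-2 checkpoint routes through the patched llama_model_forward above. The import path for optimize_llm and the use of a plain Hugging Face checkpoint are assumptions made for this sketch; the real ipex-llm NPU entry point may wrap these steps differently.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ipex_llm.transformers.npu_models.convert import optimize_llm  # assumed path

model_path = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_path)
optimize_llm(model)  # merge QKV/MLP projections and swap in the NPU forwards

inputs = tokenizer("What is AI?", return_tensors="pt")
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))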


def llama_attention_forward(
    self,
    hidden_states: torch.Tensor,