ipex-llm/python/llm/src/ipex_llm/transformers/models/mllama.py
2024-12-24 14:18:31 +08:00

276 lines
11 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional, Tuple, Union
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mllama.modeling_mllama import MllamaVisionAttention
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from ipex_llm.transformers.utils import invalidInputError
def merge_qkv(module: torch.nn.Module):
    """Apply ``merge_qkv_base`` to ``module`` for each Mllama attention class.

    The vision attention class is handled first, followed by the text
    self-attention class. Nothing is returned; any effect happens through
    ``merge_qkv_base`` itself.
    """
    for attention_cls in (MllamaVisionAttention, MllamaTextSelfAttention):
        merge_qkv_base(module, attention_cls)
def mllama_vision_attention_forward(
    self,
    hidden_state: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = None,
):
    """Forward pass for Mllama vision attention using the merged ``qkv_proj``.

    Args:
        hidden_state: Input whose ``size()`` unpacks to
            ``(bsz, q_len, hidden)``.
        attention_mask: Optional mask forwarded to
            ``scaled_dot_product_attention``.
        output_attentions: Accepted for interface compatibility. Attention
            weights are never materialised here, so the second return value
            is always ``None``.

    Returns:
        A tuple ``(output, attn_weights)`` where ``output`` is the result of
        ``self.o_proj`` and ``attn_weights`` is ``None``.
    """
    bsz, q_len, _ = hidden_state.size()

    # One projection yields q, k and v stacked along the head axis.
    qkv = self.qkv_proj(hidden_state)
    qkv = qkv.view(bsz, q_len, 3 * self.num_heads, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query, key, value = qkv.chunk(3, dim=1)

    # The caller's mask must be forwarded here. Previously an unrelated
    # function object was passed in the mask position, which silently
    # discarded ``attention_mask``.
    attn_output = scaled_dot_product_attention(
        query, key.contiguous(), value.contiguous(),
        attention_mask, False
    )

    # Back to (bsz, q_len, heads * head_dim) before the output projection.
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)
    output = self.o_proj(attn_output)

    # Attention weights are not computed on this path.
    attn_weights = None
    return output, attn_weights
def mllama_text_model_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cross_attention_states: Optional[torch.FloatTensor] = None,
    cross_attention_mask: Optional[torch.Tensor] = None,
    full_text_row_masked_out_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Forward pass of the Mllama text model with IPEX-LLM cache handling.

    Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. Flags left
    as ``None`` (``output_attentions``, ``output_hidden_states``,
    ``use_cache``, ``return_dict``) fall back to the values in ``self.config``.
    ``use_cache`` is forced to ``True`` when the inputs live on an ``xpu``
    device.

    When caching is enabled, ``past_key_values`` is converted to a
    ``DynamicFp8Cache`` or ``DynamicNormalCache`` depending on
    ``use_quantize_kv_cache``. Cross-attention layers are skipped when there
    are neither new ``cross_attention_states`` nor cached cross key/values.

    Returns:
        A ``BaseModelOutputWithPast`` when ``return_dict`` is true; otherwise
        a tuple of the non-``None`` values among
        ``(hidden_states, cache, all_hidden_states, all_self_attns)``.

    Raises:
        Whatever ``invalidInputError`` raises when both or neither of
        ``input_ids`` and ``inputs_embeds`` are provided.
    """
    # Validate before anything dereferences the inputs: the device lookup
    # below would otherwise fail with an opaque AttributeError on None when
    # neither input is supplied.
    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
                      "You cannot specify both input_ids and inputs_embeds at the same time, "
                      "and must specify either one")

    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    # IPEX-LLM OPT start: kv cache and quantize kv cache
    inputs = input_ids if input_ids is not None else inputs_embeds
    use_cache = True if inputs.device.type == "xpu" else use_cache
    use_quantize_kv = use_quantize_kv_cache(
        self.layers[0].mlp.down_proj, inputs,
        self.config.num_attention_heads // self.config.num_key_value_heads
    )
    if use_cache:
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    # IPEX-LLM OPT end

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    hidden_states = inputs_embeds

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
            device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # IPEX-LLM OPT start: use fused rope
    # When fused rope is usable, the layers are handed the raw inverse
    # frequencies instead of the embeddings computed above.
    if (should_use_fuse_rope(hidden_states, position_ids, False)
            and self.rotary_emb.rope_type == "llama3"):
        position_embeddings = self.rotary_emb.inv_freq
    # IPEX-LLM OPT end

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # For text-only path we should skip cross attention layers.
        # Let's check if the layer is cross attention layer and if we have cross attention states
        # or cached cross attention states.
        is_cross_attention_layer = idx in self.cross_attention_layers
        # IPEX-LLM change start
        if is_cross_attention_layer and cross_attention_states is None:
            if past_key_values is None:
                # use_cache=False
                continue
            elif len(past_key_values.key_cache) <= idx:
                # first token but no cross_attention_states, means no image inputs;
                # append empty placeholders so later layers keep their cache index
                past_key_values.key_cache.append([])
                past_key_values.value_cache.append([])
                continue
            elif past_key_values.key_cache[idx] == []:
                # next token but no cross kv cache, means no image inputs
                continue
        # IPEX-LLM change end

        layer_outputs = decoder_layer(
            hidden_states,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            attention_mask=causal_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache sits after the attention weights when those are returned.
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                     if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
def mllama_cross_attention_forward(
    self,
    hidden_states: torch.Tensor,
    cross_attention_states: Optional[torch.Tensor] = None,
    past_key_value: Optional[Cache] = None,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    use_cache: bool = None,
    cache_position: Optional[torch.LongTensor] = None,
):
    """Forward pass for Mllama text cross-attention.

    Queries come from ``hidden_states``. Keys and values are projected from
    ``cross_attention_states`` when it is given (and stored in
    ``past_key_value`` if a cache is present); otherwise they are read back
    from ``past_key_value`` at ``self.layer_idx``.

    Args:
        hidden_states: Input whose ``size()`` unpacks to
            ``(bsz, q_len, hidden)``.
        cross_attention_states: Optional states to project into keys/values.
        past_key_value: Optional cache object. It must be provided when
            ``cross_attention_states`` is ``None``.
        attention_mask: Optional mask forwarded to
            ``scaled_dot_product_attention``.
        output_attentions: Accepted for interface compatibility; attention
            weights are never materialised here.
        use_cache: Unused in this function.
        cache_position: Unused in this function.

    Returns:
        A tuple ``(attn_output, attn_weights, past_key_value)`` where
        ``attn_weights`` is always ``None``.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    # q_norm is applied per head, hence the flatten to (-1, head_dim).
    query_states = self.q_norm(query_states.view(-1, self.head_dim))
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    if cross_attention_states is not None:
        key_states = self.k_proj(cross_attention_states)
        value_states = self.v_proj(cross_attention_states)
        key_states = self.k_norm(key_states.view(-1, self.head_dim))
        key_states = key_states.view(bsz, -1, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, -1, self.num_key_value_heads,
                                         self.head_dim).transpose(1, 2)

        # if we have a new image + new tokens, we only computed key_states on that new image
        # we still update the cross key states, past_image, new_image. And use it!
        # Without a cache (e.g. use_cache=False) there is nothing to update,
        # so the freshly projected states are used directly.
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, None
            )
    else:
        key_states, value_states = (
            past_key_value.key_cache[self.layer_idx],
            past_key_value.value_cache[self.layer_idx],
        )

    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == key_states.size(2)
    )

    # Back to (bsz, q_len, heads * head_dim) before the output projection.
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)

    # Attention weights are not computed on this path.
    attn_weights = None
    return attn_output, attn_weights, past_key_value