diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index c3893669..28b927b2 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1023,6 +1023,9 @@ def _optimize_pre(model, qtype=None): if model.config.model_type == "llama": from ipex_llm.transformers.models.llama import merge_qkv model.apply(merge_qkv) + if model.config.model_type == "mllama": + from ipex_llm.transformers.models.mllama import merge_qkv + model.apply(merge_qkv) if model.config.model_type == "minicpm": from ipex_llm.transformers.models.minicpm import merge_qkv model.apply(merge_qkv) @@ -1284,12 +1287,19 @@ def _optimize_post(model, lightweight_bmm=False): # llama 3.2 vision modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) - from ipex_llm.transformers.models.common import rms_norm_forward - from ipex_llm.transformers.models.common import mlp_silu_forward from ipex_llm.transformers.models.mllama import mllama_vision_attention_forward convert_forward(model, module.MllamaVisionAttention, mllama_vision_attention_forward) + + from ipex_llm.transformers.models.common import rms_norm_forward + from ipex_llm.transformers.models.common import mlp_silu_forward + from ipex_llm.transformers.models.llama32 import llama_attention_forward + from ipex_llm.transformers.models.mllama import mllama_text_model_forward + from ipex_llm.transformers.models.mllama import mllama_cross_attention_forward convert_forward(model, module.MllamaTextRMSNorm, rms_norm_forward) convert_forward(model, module.MllamaTextMLP, mlp_silu_forward) + convert_forward(model, module.MllamaTextModel, mllama_text_model_forward) + convert_forward(model, module.MllamaTextSelfAttention, llama_attention_forward) + convert_forward(model, module.MllamaTextCrossAttention, mllama_cross_attention_forward) elif model.config.model_type == "llama": from transformers.models.llama.modeling_llama import LlamaRMSNorm from transformers.models.llama.modeling_llama import LlamaMLP diff --git a/python/llm/src/ipex_llm/transformers/models/mllama.py b/python/llm/src/ipex_llm/transformers/models/mllama.py index 9752ebe9..2f1142b7 100644 --- a/python/llm/src/ipex_llm/transformers/models/mllama.py +++ b/python/llm/src/ipex_llm/transformers/models/mllama.py @@ -35,8 +35,23 @@ import math import torch -from typing import Optional -from ipex_llm.transformers.models.utils import use_sdp_non_causal +from typing import Optional, Tuple, Union +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.mllama.modeling_mllama import MllamaVisionAttention +from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention +from transformers.models.mllama.modeling_mllama import repeat_kv +from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_sdp_non_causal +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import should_use_fuse_rope +from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax +from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache +from ipex_llm.transformers.utils import invalidInputError + + +def merge_qkv(module: torch.nn.Module): + merge_qkv_base(module, MllamaVisionAttention) + merge_qkv_base(module, MllamaTextSelfAttention) def mllama_vision_attention_forward( @@ -45,16 +60,12 @@ def mllama_vision_attention_forward( attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = None, ): - query = self.q_proj(hidden_state) - key = self.k_proj(hidden_state) - value = self.v_proj(hidden_state) + bsz, q_len, _ = hidden_state.size() - batch_size, q_seq_len, _ = query.shape - _, kv_seq_len, _ = key.shape - - query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) - value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + qkv = self.qkv_proj(hidden_state) + qkv = qkv.view(bsz, q_len, 3 * self.num_heads, self.head_dim) + qkv = qkv.transpose(1, 2) + query, key, value = qkv.chunk(3, dim=1) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -79,7 +90,7 @@ def mllama_vision_attention_forward( attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + attn_output = attn_output.reshape(bsz, q_len, -1) output = self.o_proj(attn_output) @@ -87,3 +98,227 @@ def mllama_vision_attention_forward( attn_weights = None return output, attn_weights + + +def mllama_text_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # IPEX-LLM OPT start: kv cache and quantize kv cache + inputs = input_ids if input_ids is not None else inputs_embeds + use_cache = True if inputs.device.type == "xpu" else use_cache + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs) + if use_cache: + if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): + past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) + elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) + # IPEX-LLM OPT end + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + invalidInputError((input_ids is None) ^ (inputs_embeds is None), + "You cannot specify both input_ids and inputs_embeds at the same time, " + "and must specify either one") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # IPEX-LLM OPT start: use fused rope + if (should_use_fuse_rope(hidden_states, position_ids, False) + and self.rotary_emb.rope_type == "llama3"): + position_embeddings = self.rotary_emb.inv_freq + # IEPX_LLM OPT end + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # For text-only path we should skip cross attention layers. + # Let's check if the layer is cross attention layer and if we have cross attention states + # or cached cross attention states. + is_cross_attention_layer = idx in self.cross_attention_layers + + # IPEX-LLM change start + if is_cross_attention_layer and cross_attention_states is None: + if past_key_values is None: + # use_cache=False + continue + elif len(past_key_values.key_cache) <= idx: + # first token but no cross_attention_states, means no image inputs + past_key_values.key_cache.append([]) + past_key_values.value_cache.append([]) + continue + elif past_key_values.key_cache[idx] == []: + # next token but no cross kv cache, means no image inputs + continue + # IPEX-LLM change end + + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def mllama_cross_attention_forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, +): + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = self.q_norm(query_states.view(-1, self.head_dim)) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = self.k_norm(key_states.view(-1, self.head_dim)) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, None + ) + else: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + + kv_seq_len = key_states.size(2) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :kv_seq_len] + else: + causal_mask = None + + attn_weights = None + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): + import xe_addons + if isinstance(past_key_value, DynamicFp8Cache): + attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask) + else: + attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): + import xe_addons + if isinstance(past_key_value, DynamicFp8Cache): + attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, + value_states, causal_mask) + else: + attn_output = xe_addons.sdp_causal(query_states, key_states, + value_states, causal_mask) + else: + if isinstance(past_key_value, DynamicFp8Cache): + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if causal_mask is not None: + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = attention_softmax(attn_weights, self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value