From ddc0ef3993c3bdb047152d51ec5569a9954d799e Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 7 Jan 2025 11:15:51 +0800 Subject: [PATCH] refactor device check and remove cohere/mixtral support (#12659) --- .../llm/src/ipex_llm/transformers/convert.py | 82 +-- .../llm/src/ipex_llm/transformers/lookup.py | 4 +- .../ipex_llm/transformers/low_bit_linear.py | 16 +- .../ipex_llm/transformers/models/cohere.py | 589 ------------------ .../ipex_llm/transformers/models/minicpmv.py | 4 +- .../ipex_llm/transformers/models/mixtral.py | 576 ----------------- .../src/ipex_llm/transformers/models/sd.py | 4 +- .../src/ipex_llm/transformers/models/utils.py | 103 +-- python/llm/src/ipex_llm/transformers/utils.py | 25 +- 9 files changed, 44 insertions(+), 1359 deletions(-) delete mode 100644 python/llm/src/ipex_llm/transformers/models/cohere.py delete mode 100644 python/llm/src/ipex_llm/transformers/models/mixtral.py diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 807581f6..8979b5ca 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1710,31 +1710,6 @@ def _optimize_post(model): convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward) convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward) convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward) - elif model.config.model_type == "cohere": - # for CohereForAI/c4ai-command-r-v01 - invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"), - "Please upgrade transformers to 4.40.0 or higher version " - "to run Mixtral models.") - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - if version.parse(trans_version) >= version.parse("4.41.0"): - from ipex_llm.transformers.models.cohere import cohere_model_forward_4_41 - convert_forward(model, - module.CohereModel, - cohere_model_forward_4_41) - else: - from ipex_llm.transformers.models.cohere import cohere_model_forward - convert_forward(model, - module.CohereModel, - cohere_model_forward) - - from ipex_llm.transformers.models.cohere import cohere_attention_forward - convert_forward(model, - module.CohereAttention, - cohere_attention_forward) - convert_forward(model, - module.CohereMLP, - mlp_silu_forward) elif model.config.model_type == "aquila": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) @@ -1746,31 +1721,6 @@ def _optimize_post(model): convert_forward(model, module.AquilaRMSNorm, rms_norm_forward) - elif model.config.model_type == "mixtral": - # For mistralai/Mixtral-8x7B-v0.1 - invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"), - "Please upgrade transformers to 4.36.0 or higher version " - "to run Mixtral models.") - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - from ipex_llm.transformers.models.mixtral import mixtral_moeblock_forward, \ - mixtral_attention_forward, mixtral_mlp_forward, mixtral_model_forward - convert_forward(model, - module.MixtralAttention, - mixtral_attention_forward) - convert_forward(model, - module.MixtralRMSNorm, - rms_norm_forward) - convert_forward(model, - module.MixtralSparseMoeBlock, - mixtral_moeblock_forward) - convert_forward(model, - module.MixtralBLockSparseTop2MLP, - mixtral_mlp_forward) - convert_forward(model, - module.MixtralModel, - mixtral_model_forward) - elif model.config.model_type == "phi-msft" and \ hasattr(model.config, "num_local_experts"): # For phixtral, limit the condition to avoid applying on phi-2 hosted by ModelScope @@ -1785,29 +1735,19 @@ def _optimize_post(model): module.MLP, phixtral_mlp_forward) elif model.config.model_type == "mistral": - if model.config.architectures is not None and \ - model.config.architectures[0] == "MixtralForCausalLM": - # For DiscoResearch/mixtral-7b-8expert - invalidInputError(version.parse(trans_version) >= version.parse("4.36.0"), - "Please upgrade transformers to 4.36.0 or higher version " - "to run Mixtral models.") - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) - convert_forward(model, module.MistralRMSNorm, rms_norm_forward) - else: - modeling_module_name = model.__class__.__module__ - module = importlib.import_module(modeling_module_name) + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) - from ipex_llm.transformers.models.mistral import mistral_model_forward - from ipex_llm.transformers.models.mistral import mistral_attention_forward - from ipex_llm.transformers.models.common import rms_norm_forward - from ipex_llm.transformers.models.common import mlp_silu_forward + from ipex_llm.transformers.models.mistral import mistral_model_forward + from ipex_llm.transformers.models.mistral import mistral_attention_forward + from ipex_llm.transformers.models.common import rms_norm_forward + from ipex_llm.transformers.models.common import mlp_silu_forward - convert_forward(model, module.MistralModel, mistral_model_forward) - convert_forward(model, module.MistralAttention, mistral_attention_forward) - convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward) - convert_forward(model, module.MistralRMSNorm, rms_norm_forward) - convert_forward(model, module.MistralMLP, mlp_silu_forward) + convert_forward(model, module.MistralModel, mistral_model_forward) + convert_forward(model, module.MistralAttention, mistral_attention_forward) + convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward) + convert_forward(model, module.MistralRMSNorm, rms_norm_forward) + convert_forward(model, module.MistralMLP, mlp_silu_forward) elif model.config.model_type == "gemma": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index 08643e6e..9e9011cb 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -33,7 +33,7 @@ from ipex_llm.transformers.speculative import greedy, deepmind_sample, logits_to _crop_past_key_values, _prepare_generate_args, _non_cpu_ipex_verify, clear_benchmarks,\ _prepare_generate_args_4_45 from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.utils import get_xpu_device_type +from ipex_llm.transformers.utils import get_xpu_device_name logger = logging.getLogger("ipex_llm.lookup") @@ -295,7 +295,7 @@ def lookup_generate(self, invalidInputError(input_ids.shape[0] == 1, "Prompt lookup is currently not supported with batch inference.") - device_name = get_xpu_device_type(input_ids) + device_name = get_xpu_device_name(input_ids.device) candidates_generator = PromptLookupCandidateGenerator( num_output_tokens=num_output_tokens, diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index 28e4c083..59d03c97 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -51,7 +51,7 @@ from torch import Tensor, device, dtype, nn from operator import mul from functools import reduce from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd -from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \ +from ipex_llm.transformers.utils import get_autocast_dtype, get_xpu_device_name, \ get_ipex_version from ipex_llm.transformers.convert import is_deepspeed_available, get_use_vllm @@ -266,7 +266,7 @@ def reshape_lm_head_input(x): def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int): - device = get_xpu_device_type(x) + device_name = get_xpu_device_name(x.device) batch_size = x.shape[0] hard_condition = ( x.dtype in [torch.float, torch.half] @@ -286,7 +286,7 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int): or ( qtype in [SYM_INT8, FP4, FP6, Q4_K, Q6_K] and batch_size <= 48 - and device in ["arc", "flex", "pvc", "mtl"] + and device_name in ["arc", "pvc", "mtl", "lnl", "arl"] and x.shape[1] % 256 == 0 and output_len % 32 == 0 ) @@ -295,8 +295,8 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int): if hard_condition: return ( batch_size > 1 - or (device in ["arc", "flex"] and qtype in [SYM_INT8, FP4]) - or (device in ["arc", "flex", "mtl"] and qtype in [FP8E4]) + or (device in ["arc"] and qtype in [SYM_INT8, FP4]) + or (device in ["arc", "mtl"] and qtype in [FP8E4]) or (device in ["lnl"] and qtype in [SYM_INT4] and x.shape[1] % 512 == 0) or (device in ["bmg"] and qtype in [SYM_INT4, FP8E5]) ) @@ -603,7 +603,7 @@ class LowBitLinear(nn.Linear): # empty cache before and after lm_head at first token when input > 1024 # on arc or IPEX_LLM_LOW_MEM is set to 1 at inference time. if self.device is None: - self.device = get_xpu_device_type(self.weight.data) + self.device = get_xpu_device_name(self.weight.data.device) self.low_memory_mode = \ self.low_memory_mode and \ (self.device == "arc" or os.environ.get("IPEX_LLM_LOW_MEM", None) == "1") @@ -782,7 +782,7 @@ class FP16Linear(nn.Linear): if not self.use_esimd_kernel(x): if ( get_ipex_version() < "2.1.10+xpu" - or get_xpu_device_type(x) not in ["arc", "flex", "pvc"] + or get_xpu_device_name(x.device) not in ["arc", "pvc"] or self.disable_fp16_opt ): if self.weight_type == 2: @@ -848,7 +848,7 @@ class FP16Linear(nn.Linear): return result.to(x.dtype) def use_esimd_kernel(self, x): - gpu_type = get_xpu_device_type(x) + gpu_type = get_xpu_device_name(x.device) if self.disable_fp16_opt: return False # esimd kernel can only be used for Arc and Flex diff --git a/python/llm/src/ipex_llm/transformers/models/cohere.py b/python/llm/src/ipex_llm/transformers/models/cohere.py deleted file mode 100644 index c37e2ea6..00000000 --- a/python/llm/src/ipex_llm/transformers/models/cohere.py +++ /dev/null @@ -1,589 +0,0 @@ -# -# Copyright 2016 The BigDL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Some parts of this file is adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py - -# coding=utf-8 -# Copyright 2024 Cohere team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is based on the LLama model definition file in transformers - -"""PyTorch Cohere model.""" -import math -import torch -import torch.nn.functional as F -import torch.nn as nn -import torch.utils.checkpoint -from typing import Optional, Tuple, List -from ipex_llm.transformers.models.utils import repeat_kv -from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache -from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb -from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36 -from ipex_llm.transformers.models.utils import use_decoding_fast_path -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp -from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache -from ipex_llm.transformers.kv import DynamicFp8Cache -from ipex_llm.transformers.models.utils import should_use_fuse_rope -from transformers.modeling_outputs import BaseModelOutputWithPast -from ipex_llm.utils.common import invalidInputError -try: - from transformers.cache_utils import Cache, DynamicCache -except ImportError: - Cache = Tuple[torch.Tensor] - -KV_CACHE_ALLOC_BLOCK_LENGTH = 256 - - -def cohere_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, -): - use_cache = use_cache if use_cache is not None \ - else self.config.use_cache - if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids): - if not isinstance(past_key_values, DynamicFp8Cache): - past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) - output_attentions = output_attentions if output_attentions is not None \ - else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - invalidInputError(False, - "You cannot specify both input_ids and inputs_embeds at the same time") - - if self.gradient_checkpointing and self.training and use_cache: - invalidInputError(False, - "`use_cache=True` is incompatible " - "with gradient checkpointing. Setting `use_cache=False`.") - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, Cache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - - if cache_position is None: - if isinstance(past_key_values, Cache): - invalidInputError(False, "cache_position is a required argument when using Cache.") - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask(attention_mask, - inputs_embeds, cache_position, past_seen_tokens) - - # embed positions - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - # ipex-llm changes - curr_device = decoder_layer.input_layernorm.weight.device - if causal_mask is not None: - causal_mask = causal_mask.to(curr_device) - if position_ids is not None: - position_ids = position_ids.to(curr_device) - # ipex-llm changes end - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, - all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -def cohere_model_forward_4_41( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, -): - use_cache = use_cache if use_cache is not None \ - else self.config.use_cache - if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids): - if not isinstance(past_key_values, DynamicFp8Cache): - past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) - output_attentions = output_attentions if output_attentions is not None \ - else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - invalidInputError(False, - "You cannot specify both input_ids and inputs_embeds at the same time") - - if self.gradient_checkpointing and self.training and use_cache: - invalidInputError(False, - "`use_cache=True` is incompatible " - "with gradient checkpointing. Setting `use_cache=False`.") - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - past_seen_tokens = 0 - return_legacy_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - # embed positions - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - # ipex-llm changes - curr_device = decoder_layer.input_layernorm.weight.device - if causal_mask is not None: - causal_mask = causal_mask.to(curr_device) - if position_ids is not None: - position_ids = position_ids.to(curr_device) - # ipex-llm changes end - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, - all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -def cohere_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if use_quantize_kv_cache(self.q_proj, hidden_states): - forward_function = cohere_attention_forward_quantized - else: - forward_function = cohere_attention_forward_origin - return forward_function( - self=self, - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - -def cohere_attention_forward_quantized( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - if self.use_qk_norm: - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - - past_key_value = getattr(self, "past_key_value", past_key_value) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, - cache_kwargs, new_layout=True) - if q_len == 1 and query_states.device.type == 'xpu' and not self.training \ - and not hidden_states.requires_grad: - import xe_addons - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask) - attn_weights = None - else: - key_states, value_states = restore_fp8_kv_cache(key_states, - value_states, query_states.dtype) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - "`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def cohere_attention_forward_origin( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training) - enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj, - use_fuse_rope, - enough_kv_room, - bsz * q_len) - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - kv_seq_len = cache_k.shape[-2] - import xe_linear - query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - position_ids, - cache_k, cache_v, - self.q_proj.weight.qtype, - self.v_proj.weight.qtype, - kv_seq_len, - self.head_dim, - self.rotary_emb.base,) - kv_seq_len += 1 - # update past_key_value's seem_tokens and kv caches. - if self.layer_idx == 0: - past_key_value._seen_tokens = kv_seq_len - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - if self.use_qk_norm: - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - past_key_value = getattr(self, "past_key_value", past_key_value) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - invalidInputError( - False, - "The cache structure has changed since version v4.36. " - f"If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, " - "please make sure to initialize the attention class with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - if self.layer_idx == 0: - past_key_value._seen_tokens += key_states.shape[-2] - - if len(past_key_value.key_cache) <= self.layer_idx: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - if not enough_kv_room: - # allocate new - new_c_k, new_c_v = extend_kv_cache(bsz, - self.num_key_value_heads, # Support GQA - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - - new_c_k[:] = cache_k - new_c_v[:] = cache_v - cache_k = new_c_k - cache_v = new_c_v - - key_states, value_states = append_kv_cache(cache_k, - cache_v, - key_states, - value_states) - - # update past_key_value - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query_states, key_states, attention_mask): - attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16), - key_states.to(device, dtype=torch.float16), - value_states.to(device, dtype=torch.float16), - is_causal=True) - attn_weights = None - elif not self.training and not hidden_states.requires_grad and \ - use_sdp(q_len, key_states.shape[2], self.head_dim, query_states): - import xe_addons - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - else: - causal_mask = None - attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask) - attn_output = attn_output.view(query_states.shape) - attn_weights = None - else: - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - "`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output.to(hidden_states.dtype), attn_weights, past_key_value diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py index 6bfbf460..9e0f1085 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py @@ -53,10 +53,10 @@ def siglip_attention_forward( qkv = qkv.transpose(1, 2) query_states, key_states, value_states = qkv.chunk(3, dim=1) - from ipex_llm.transformers.utils import get_xpu_device_type + from ipex_llm.transformers.utils import get_xpu_device_name if ( self.head_dim == 72 - and get_xpu_device_type(query_states) in ["arc", "flex"] and + and get_xpu_device_name(query_states.device) == "arc" and query_states.dtype in [torch.float, torch.half] ): n_heads, kv_length = query_states.size(1), key_states.size(2) diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py deleted file mode 100644 index 6a98899a..00000000 --- a/python/llm/src/ipex_llm/transformers/models/mixtral.py +++ /dev/null @@ -1,576 +0,0 @@ -# -# Copyright 2016 The BigDL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Some parts of this file is adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py - -# coding=utf-8 -# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" PyTorch Mixtral model.""" -import math -from typing import Optional, Tuple, Union, List -from transformers.modeling_outputs import MoeModelOutputWithPast -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, -) - -import torch -from torch import nn -import torch.nn.functional as F -from ipex_llm.ggml.quantize import ggml_tensor_qtype -from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache -from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36 -from ipex_llm.transformers.models.utils import should_use_fuse_rope -from ipex_llm.transformers.models.utils import use_decoding_fast_path -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp -from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU -from ipex_llm.transformers.low_bit_linear import IQ2_XXS - -import os - -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) - to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def mixtral_moeblock_forward(self, - hidden_states: torch.Tensor): - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - bs = hidden_states.shape[0] - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - if bs == 1: - selected_experts = selected_experts[0].cpu().tolist() - for idx in range(self.top_k): - exp_id = selected_experts[idx] - expert_layer = self.experts[exp_id] - weight = routing_weights[:, idx] - if idx == 0: - final_hidden_states = expert_layer(hidden_states, weight) - else: - final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight) - elif bs < 256 and hidden_states.device.type == 'xpu': - final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, device=hidden_states.device) - import xe_linear - indexes = xe_linear.get_moe_indexes(selected_experts.to(torch.int32).cpu(), 8) - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx_list = indexes[0][expert_idx] - top_x_list = indexes[1][expert_idx] - if len(idx_list) == 0: - continue - - top_x = torch.tensor(top_x_list, dtype=torch.long, device=hidden_states.device) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state, - routing_weights[top_x_list, idx_list, None]) - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - else: - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, - num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - if top_x.shape[0] == 0: - continue - - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state, - routing_weights[top_x_list, idx_list, None]) - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits - - -def mixtral_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor]=None, - position_ids: Optional[torch.LongTensor]=None, - past_key_value: Optional[Tuple[torch.Tensor]]=None, - output_attentions: bool=False, - use_cache: bool=False, - padding_mask: Optional[torch.Tensor]=None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - # for flash attention - original_dtype = hidden_states.dtype - - use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training) - enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj, - use_fuse_rope, - enough_kv_room, - bsz * q_len) - - if decoding_fast_path: - hidden_states = hidden_states.view(1, -1) - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - kv_seq_len = cache_k.shape[-2] - import xe_linear - query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - position_ids, - cache_k, cache_v, - self.q_proj.weight.qtype, - self.v_proj.weight.qtype, - kv_seq_len, - self.head_dim, - self.rotary_emb.base,) - kv_seq_len += 1 - # update past_key_value's seem_tokens and kv caches. - if self.layer_idx == 0: - past_key_value.seen_tokens = kv_seq_len - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - # diasble it for now as it will cause output change for unknown reason - # elif decoding_fast_path and self.q_proj.qtype == IQ2_XXS: - # # this path self.v_proj use q4_0 - # hidden_states = hidden_states.view(1, -1) - # cache_k = past_key_value.key_cache[self.layer_idx] - # cache_v = past_key_value.value_cache[self.layer_idx] - # kv_seq_len = cache_k.shape[-2] - # import xe_linear - # query_states, key_states = xe_linear.forward_qk(hidden_states, - # self.q_proj.weight, - # self.k_proj.weight, - # position_ids, - # cache_k, - # self.q_proj.weight.qtype, - # kv_seq_len, - # self.head_dim, - # 10000) - # kv_seq_len += 1 - # # update past_key_value's seem_tokens and kv caches. - # if self.layer_idx == 0: - # past_key_value.seen_tokens = kv_seq_len - # # update value_states - # value_states = self.v_proj(hidden_states) - # value_states = value_states.view(bsz, q_len, - # self.num_key_value_heads, self.head_dim).transpose(1, 2) - # new_size = (cache_v.size(0), - # cache_v.size(1), - # cache_v.size(2) + value_states.size(2), - # cache_v.size(3)) - # new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0) - # new_cache_v[:, :, cache_v.size(2):cache_v.size(2)+value_states.size(2), :] = value_states - - # past_key_value.key_cache[self.layer_idx] = key_states - # past_key_value.value_cache[self.layer_idx] = new_cache_v - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, - self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - invalidInputError(False, - "The cache structure has changed since version v4.36. " - f"If you are using {self.__class__.__name__} for " - "auto-regressive decodingwith k/v caching, please make sure " - "to initialize the attention class with a layer index.") - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - if use_fuse_rope: - import xe_addons - xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids, - query_states, key_states) - else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "mixtral") - - if past_key_value is not None: - # update the number of seen tokens - if self.layer_idx == 0: - past_key_value.seen_tokens += key_states.shape[-2] - - # reuse k, v, self_attention - # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx` - if len(past_key_value.key_cache) <= self.layer_idx: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - if not enough_kv_room: - # allocate new - new_c_k, new_c_v = extend_kv_cache(bsz, - self.num_key_value_heads, # Support GQA - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - - new_c_k[:] = cache_k - new_c_v[:] = cache_v - cache_k = new_c_k - cache_v = new_c_v - - key_states, value_states = append_kv_cache(cache_k, - cache_v, - key_states, - value_states) - - # update past_key_value - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - if not self.training and not hidden_states.requires_grad: - fsdp_flag = use_flash_attention(query_states, key_states) - else: - fsdp_flag = False - if fsdp_flag: - attention_dtype = torch.float16 # use fp16 for flash attention - else: - attention_dtype = original_dtype - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, - dtype=attention_dtype) - value_states = repeat_kv(value_states, self.num_key_value_groups).to(device, - dtype=attention_dtype) - - if fsdp_flag: - attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype), - key_states, - value_states, - is_causal=True) - attn_weights = None - elif use_sdp(query_states.shape[2], key_states.shape[2], self.head_dim, query_states): - import xe_addons - attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) - attn_output = attn_output.view(query_states.shape) - attn_weights = None - else: - attn_weights = torch.matmul( - query_states.to(key_states.dtype), - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}," - f" but is {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - invalidInputError( - False, - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.\ - softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - invalidInputError( - False, - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def mixtral_mlp_forward( - self, - x: torch.Tensor, - routing_weights -) -> torch.Tensor: - qtype = getattr(self.w1, "qtype", None) - if mlp_fusion_check(x, qtype, self.training): - import xe_linear - return self.w2(xe_linear.mlp_forward_xpu( - x, self.w1.weight.data, self.w3.weight.data, - x.shape[0], x.shape[1], self.w1.out_len, - SILU, qtype, - )) * routing_weights - else: - current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x) - current_hidden_states = self.w2(current_hidden_states) - return routing_weights * current_hidden_states - - -def mixtral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, MoeModelOutputWithPast]: - # to be compatible with transformers>=4.37.0 - self._use_flash_attention_2 = self.config._attn_implementation == "flash_attention_2" - - output_attentions = output_attentions if output_attentions is not None \ - else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None - else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") # noqa - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds") # noqa - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, - dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._use_flash_attention_2 and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - invalidInputError( - False, - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " # noqa - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask \ - if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." # noqa - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - ) - else: - # bigdl-llm changes: - # - # Avoid moving `attention_mask`` and `position_ids`` to other devices multiple times. - # - # When the model is partitioned on two different devices using - # `accelerate`'s `dispatch``, a hook to move inputs to the correct device is - # added to each layer's `forward``, which will result in moving `attention_mask` - # and `position_ids`, which allocated on device:0, to other devices for each - # decoder layer not in device:0. - # - # To avoid this, we move `attention_mask` and `position_ids` to the device of - # the current layer before the forward call, so that the moving is only done once - # for each devices other than devie:0. - # - curr_device = decoder_layer.input_layernorm.weight.device - if attention_mask is not None: - attention_mask = attention_mask.to(curr_device) - if position_ids is not None: - position_ids = position_ids.to(curr_device) - # bigdl-llm changes end - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if output_router_logits: - all_router_logits += (layer_outputs[-1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() \ - if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] # noqa - if v is not None - ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) diff --git a/python/llm/src/ipex_llm/transformers/models/sd.py b/python/llm/src/ipex_llm/transformers/models/sd.py index 06109f15..854249e4 100644 --- a/python/llm/src/ipex_llm/transformers/models/sd.py +++ b/python/llm/src/ipex_llm/transformers/models/sd.py @@ -36,7 +36,7 @@ import math import torch from typing import Optional -from ipex_llm.transformers.utils import get_xpu_device_type +from ipex_llm.transformers.utils import get_xpu_device_name from ipex_llm.transformers.models.common import padding_qkv_hd from ipex_llm.transformers.models.common import scaled_dot_product_attention from diffusers.models.attention_processor import Attention @@ -144,7 +144,7 @@ class AttnProcessor2_0: def upcast_vae(self): # workaround overflow and ipex's bugs - if get_xpu_device_type(self.vae.post_quant_conv.weight) in ["arc", "flex", "pvc"]: + if get_xpu_device_name(self.vae.post_quant_conv.weight.device) == "arc": self.vae.to(torch.bfloat16) else: self.vae.decoder.up_blocks.to(torch.bfloat16) diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 0c6f6208..cd16b71b 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -19,7 +19,7 @@ import torch import warnings from ipex_llm.utils.common import invalidInputError from ipex_llm.ggml.quantize import ggml_tensor_qtype -from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type +from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_name from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4,\ FP6, ASYM_INT4 @@ -85,16 +85,14 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1" elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None: return os.environ["IPEX_LLM_LOW_MEM"] == "1" + elif linear.qtype in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]: + return False else: - return x.device.type == 'xpu' and kv_cache_device_check(x, kv_group) \ - and hasattr(linear, "qtype") and \ - linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"] - - -def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool: - return (get_xpu_device_type(x) in ["mtl", "lnl"] and kv_group <= 1) or \ - ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and - 1 < x.size(0) and x.size(0) <= 8) + device_name = get_xpu_device_name(x.device) + return ( + device_name in ["mtl", "lnl", "arl"] and kv_group == 1 + or device_name in ["arc", "bmg"] and x.size(0) > 1 + ) def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device): @@ -226,57 +224,6 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1): (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3) -def use_flash_attention(query, key, attention_mask=None): - # here we support query's shape is always [batch_size, head_num, q_len, head_dim], - # key's shape is always [batch_size, head_num, k_len, head_dim] - invalidInputError(query.dim() == 4, - "Here query input of use_flash_attention should be [batch_size, " - "head_num, q_len, head_dim]") - invalidInputError(key.dim() == 4, - "Here key input of use_flash_attention should be [batch_size, " - "head_num, k_len, head_dim]") - bsz, _, q_len, _ = query.size() - k_len = key.size()[2] - # check whether ipex flash attention can be used - if q_len != k_len: - # now only use flash attention for first token - # as it seems have no performance benifit for rest token now - return False - if query.device.type != "xpu": - # ipex flash attention only support for xpu - return False - ipex_version = get_ipex_version() - if ipex_version <= "2.0.110+xpu": - # ipex flash attention is supported from ipex 2.1 - return False - if not torch.xpu.has_xetla(): - # ipex flash attention is only supported for xetla - # may update this later - return False - elif get_xpu_device_type(query) != "pvc": - return False - if query.dtype not in [torch.float32, torch.float16]: - # only use flash attention for fp32/fp16 input - return False - if bsz > 1: - # as flash attention doesn't support attn_mask in ipex 2.1, - # so it will cause output error for padded batch input - if attention_mask is None: - return True - else: - # TODO: below logic may change for different model - # attention mask shape : [bsz, 1, q_len, k_len] - if attention_mask[0].squeeze()[0, 0].item() != 0: - # first batch contains padding - # otherwise we suppose it should be a upper triangular matrix - # at the same time, the diagonal is also 0 - return False - elif not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)): - # check whether mask of every batch is the same - return False - return True - - def use_sdp(q_len, kv_len, head_dim, query_states): return ( query_states.device.type == "xpu" @@ -315,38 +262,16 @@ def mlp_fusion_check(x, qtype, training): if training or x.requires_grad: return False if qtype == FP6: - device = get_xpu_device_type(x) - if device in ["mtl", "lnl"]: + device = get_xpu_device_name(x.device) + if device in ["mtl", "lnl", "arl"]: return False return True -def use_decoding_fast_path(proj, - use_fuse_rope, - enough_kv_room, - bs, - qtype_check=decoding_fast_path_qtype_check): - if proj is None: - return False - device = get_xpu_device_type(proj.weight) - if not qtype_check(proj): - return False - if not use_fuse_rope: - return False - if not enough_kv_room: - return False - if bs != 1: - return False - - if device in ["uhd"]: - return False - return True - - def use_xmx(x: torch.Tensor, qtype: int): - device = get_xpu_device_type(x) + device = get_xpu_device_name(x.device) return ( - device in ["arc", "flex", "pvc"] + device in ["arc", "pvc"] and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5] and ( (device == "pvc" and 1 < x.size(0) <= 16) @@ -370,7 +295,7 @@ def fp16_fusion_check(proj, x, training): return False if x.requires_grad: return False - device_type = get_xpu_device_type(x) + device_type = get_xpu_device_name(x.device) if device_type != "pvc": return False return True @@ -439,7 +364,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int): else: if use_compress_kv is None: return ( - get_xpu_device_type(x) in ["mtl", "lnl"] + get_xpu_device_name(x.device) in ["mtl", "lnl", "arl"] and prompt_len >= 1800 and prompt_len <= 4500 ) diff --git a/python/llm/src/ipex_llm/transformers/utils.py b/python/llm/src/ipex_llm/transformers/utils.py index b2ae0ca3..e86215e1 100644 --- a/python/llm/src/ipex_llm/transformers/utils.py +++ b/python/llm/src/ipex_llm/transformers/utils.py @@ -168,27 +168,12 @@ def get_ipex_version(): return _ipex_version -def get_xpu_device_type(x): - if x.device.type != "xpu": - return x.device.type - name = torch.xpu.get_device_name(x.device.index) - if name.startswith("Intel(R) Arc(TM) A"): - return "arc" - elif name.startswith("Intel(R) Graphics [0xe20b]"): - return "bmg" - elif name.startswith("Intel(R) Arc(TM)"): - if 'V' in name: - return "lnl" - else: - return "mtl" - elif name.startswith("Intel(R) Data Center GPU Flex"): - return "flex" - elif name.startswith("Intel(R) Data Center GPU Max"): - return "pvc" - elif name.startswith("Intel(R) UHD"): - return "uhd" +def get_xpu_device_name(device: torch.device): + if device.type != "xpu": + return device.type else: - return "others" + import xe_linear + return xe_linear.get_xpu_device_name(device) def load_imatrix_data(imatrix_file):