LLM: Optimize cohere model (#10878)
* use mlp and rms * optimize kv_cache * add fuse qkv * add flash attention and fp16 sdp * error fp8 sdp * fix optimized * fix style * update * add for pp
This commit is contained in:
		
							parent
							
								
									13a44cdacb
								
							
						
					
					
						commit
						191b184341
					
				
					 2 changed files with 475 additions and 0 deletions
				
			
		| 
						 | 
					@ -1282,6 +1282,24 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
        convert_forward(model,
 | 
					        convert_forward(model,
 | 
				
			||||||
                        module.Qwen2MoeAttention,
 | 
					                        module.Qwen2MoeAttention,
 | 
				
			||||||
                        qwen2moe_attention_forward)
 | 
					                        qwen2moe_attention_forward)
 | 
				
			||||||
 | 
					    elif model.config.model_type == "cohere":
 | 
				
			||||||
 | 
					        # for CohereForAI/c4ai-command-r-v01
 | 
				
			||||||
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.cohere import cohere_attention_forward
 | 
				
			||||||
 | 
					        from ipex_llm.transformers.models.cohere import cohere_model_forward
 | 
				
			||||||
 | 
					        convert_forward(model,
 | 
				
			||||||
 | 
					                        module.CohereModel,
 | 
				
			||||||
 | 
					                        cohere_model_forward)
 | 
				
			||||||
 | 
					        convert_forward(model,
 | 
				
			||||||
 | 
					                        module.CohereAttention,
 | 
				
			||||||
 | 
					                        cohere_attention_forward)
 | 
				
			||||||
 | 
					        convert_forward(model,
 | 
				
			||||||
 | 
					                        module.CohereLayerNorm,
 | 
				
			||||||
 | 
					                        llama_rms_norm_forward)
 | 
				
			||||||
 | 
					        convert_forward(model,
 | 
				
			||||||
 | 
					                        module.CohereMLP,
 | 
				
			||||||
 | 
					                        llama_mlp_forward)
 | 
				
			||||||
    elif model.config.model_type == "aquila":
 | 
					    elif model.config.model_type == "aquila":
 | 
				
			||||||
        modeling_module_name = model.__class__.__module__
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
        module = importlib.import_module(modeling_module_name)
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										457
									
								
								python/llm/src/ipex_llm/transformers/models/cohere.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										457
									
								
								python/llm/src/ipex_llm/transformers/models/cohere.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,457 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Some parts of this file is adapted from
 | 
				
			||||||
 | 
					# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# coding=utf-8
 | 
				
			||||||
 | 
					# Copyright 2024 Cohere team. All rights reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 | 
				
			||||||
 | 
					# and OPT implementations in this library. It has been modified from its
 | 
				
			||||||
 | 
					# original forms to accommodate minor architectural differences compared
 | 
				
			||||||
 | 
					# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# This file is based on the LLama model definition file in transformers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""PyTorch Cohere model."""
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.utils.checkpoint
 | 
				
			||||||
 | 
					from typing import Optional, Tuple, List
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.llama import repeat_kv
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 | 
				
			||||||
 | 
					from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import use_decoding_fast_path
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 | 
				
			||||||
 | 
					from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 | 
				
			||||||
 | 
					from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
				
			||||||
 | 
					from ipex_llm.transformers.models.qwen2 import should_use_fuse_rope
 | 
				
			||||||
 | 
					from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
				
			||||||
 | 
					from ipex_llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    from transformers.cache_utils import Cache, DynamicCache
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    Cache = Tuple[torch.Tensor]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cohere_model_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    input_ids: torch.LongTensor = None,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
				
			||||||
 | 
					    inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
				
			||||||
 | 
					    use_cache: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_attentions: Optional[bool] = None,
 | 
				
			||||||
 | 
					    output_hidden_states: Optional[bool] = None,
 | 
				
			||||||
 | 
					    return_dict: Optional[bool] = None,
 | 
				
			||||||
 | 
					    cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None \
 | 
				
			||||||
 | 
					        else self.config.use_cache
 | 
				
			||||||
 | 
					    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
 | 
				
			||||||
 | 
					        if not isinstance(past_key_values, DynamicFp8Cache):
 | 
				
			||||||
 | 
					            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
				
			||||||
 | 
					    output_attentions = output_attentions if output_attentions is not None \
 | 
				
			||||||
 | 
					        else self.config.output_attentions
 | 
				
			||||||
 | 
					    output_hidden_states = (
 | 
				
			||||||
 | 
					        output_hidden_states if output_hidden_states is not None
 | 
				
			||||||
 | 
					        else self.config.output_hidden_states
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
				
			||||||
 | 
					    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if input_ids is not None and inputs_embeds is not None:
 | 
				
			||||||
 | 
					        invalidInputError(False,
 | 
				
			||||||
 | 
					                          "You cannot specify both input_ids and inputs_embeds at the same time")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if self.gradient_checkpointing and self.training and use_cache:
 | 
				
			||||||
 | 
					        invalidInputError(False,
 | 
				
			||||||
 | 
					                          "`use_cache=True` is incompatible "
 | 
				
			||||||
 | 
					                          "with gradient checkpointing. Setting `use_cache=False`.")
 | 
				
			||||||
 | 
					        use_cache = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if inputs_embeds is None:
 | 
				
			||||||
 | 
					        inputs_embeds = self.embed_tokens(input_ids)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    past_seen_tokens = 0
 | 
				
			||||||
 | 
					    if use_cache:  # kept for BC (cache positions)
 | 
				
			||||||
 | 
					        if not isinstance(past_key_values, Cache):
 | 
				
			||||||
 | 
					            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 | 
				
			||||||
 | 
					        past_seen_tokens = past_key_values.get_seq_length()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if cache_position is None:
 | 
				
			||||||
 | 
					        if isinstance(past_key_values, Cache):
 | 
				
			||||||
 | 
					            invalidInputError(False, "cache_position is a required argument when using Cache.")
 | 
				
			||||||
 | 
					        cache_position = torch.arange(
 | 
				
			||||||
 | 
					            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if position_ids is None:
 | 
				
			||||||
 | 
					        position_ids = cache_position.unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    causal_mask = self._update_causal_mask(attention_mask,
 | 
				
			||||||
 | 
					                                           inputs_embeds, cache_position, past_seen_tokens)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # embed positions
 | 
				
			||||||
 | 
					    hidden_states = inputs_embeds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # decoder layers
 | 
				
			||||||
 | 
					    all_hidden_states = () if output_hidden_states else None
 | 
				
			||||||
 | 
					    all_self_attns = () if output_attentions else None
 | 
				
			||||||
 | 
					    next_decoder_cache = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for decoder_layer in self.layers:
 | 
				
			||||||
 | 
					        if output_hidden_states:
 | 
				
			||||||
 | 
					            all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.gradient_checkpointing and self.training:
 | 
				
			||||||
 | 
					            layer_outputs = self._gradient_checkpointing_func(
 | 
				
			||||||
 | 
					                decoder_layer.__call__,
 | 
				
			||||||
 | 
					                hidden_states,
 | 
				
			||||||
 | 
					                causal_mask,
 | 
				
			||||||
 | 
					                position_ids,
 | 
				
			||||||
 | 
					                past_key_values,
 | 
				
			||||||
 | 
					                output_attentions,
 | 
				
			||||||
 | 
					                use_cache,
 | 
				
			||||||
 | 
					                cache_position,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # ipex-llm changes
 | 
				
			||||||
 | 
					            curr_device = decoder_layer.input_layernorm.weight.device
 | 
				
			||||||
 | 
					            if causal_mask is not None:
 | 
				
			||||||
 | 
					                causal_mask = causal_mask.to(curr_device)
 | 
				
			||||||
 | 
					            if position_ids is not None:
 | 
				
			||||||
 | 
					                position_ids = position_ids.to(curr_device)
 | 
				
			||||||
 | 
					            # ipex-llm changes end
 | 
				
			||||||
 | 
					            layer_outputs = decoder_layer(
 | 
				
			||||||
 | 
					                hidden_states,
 | 
				
			||||||
 | 
					                attention_mask=causal_mask,
 | 
				
			||||||
 | 
					                position_ids=position_ids,
 | 
				
			||||||
 | 
					                past_key_value=past_key_values,
 | 
				
			||||||
 | 
					                output_attentions=output_attentions,
 | 
				
			||||||
 | 
					                use_cache=use_cache,
 | 
				
			||||||
 | 
					                cache_position=cache_position,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hidden_states = layer_outputs[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if use_cache:
 | 
				
			||||||
 | 
					            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if output_attentions:
 | 
				
			||||||
 | 
					            all_self_attns += (layer_outputs[1],)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hidden_states = self.norm(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # add hidden states from the last decoder layer
 | 
				
			||||||
 | 
					    if output_hidden_states:
 | 
				
			||||||
 | 
					        all_hidden_states += (hidden_states,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    next_cache = next_decoder_cache if use_cache else None
 | 
				
			||||||
 | 
					    if not return_dict:
 | 
				
			||||||
 | 
					        return tuple(v for v in [hidden_states, next_cache,
 | 
				
			||||||
 | 
					                                 all_hidden_states, all_self_attns] if v is not None)
 | 
				
			||||||
 | 
					    return BaseModelOutputWithPast(
 | 
				
			||||||
 | 
					        last_hidden_state=hidden_states,
 | 
				
			||||||
 | 
					        past_key_values=next_cache,
 | 
				
			||||||
 | 
					        hidden_states=all_hidden_states,
 | 
				
			||||||
 | 
					        attentions=all_self_attns,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cohere_attention_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
				
			||||||
 | 
					    output_attentions: bool = False,
 | 
				
			||||||
 | 
					    use_cache: bool = False,
 | 
				
			||||||
 | 
					    cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
				
			||||||
 | 
					    if use_quantize_kv_cache(self.q_proj, hidden_states):
 | 
				
			||||||
 | 
					        forward_function = cohere_attention_forward_quantized
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        forward_function = cohere_attention_forward_origin
 | 
				
			||||||
 | 
					    return forward_function(
 | 
				
			||||||
 | 
					        self=self,
 | 
				
			||||||
 | 
					        hidden_states=hidden_states,
 | 
				
			||||||
 | 
					        attention_mask=attention_mask,
 | 
				
			||||||
 | 
					        position_ids=position_ids,
 | 
				
			||||||
 | 
					        past_key_value=past_key_value,
 | 
				
			||||||
 | 
					        output_attentions=output_attentions,
 | 
				
			||||||
 | 
					        use_cache=use_cache,
 | 
				
			||||||
 | 
					        cache_position=cache_position,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cohere_attention_forward_quantized(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
				
			||||||
 | 
					    output_attentions: bool = False,
 | 
				
			||||||
 | 
					    use_cache: bool = False,
 | 
				
			||||||
 | 
					    cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
				
			||||||
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    query_states = self.q_proj(hidden_states)
 | 
				
			||||||
 | 
					    key_states = self.k_proj(hidden_states)
 | 
				
			||||||
 | 
					    value_states = self.v_proj(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
 | 
				
			||||||
 | 
					    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
 | 
				
			||||||
 | 
					    if self.use_qk_norm:
 | 
				
			||||||
 | 
					        query_states = self.q_norm(query_states)
 | 
				
			||||||
 | 
					        key_states = self.k_norm(key_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    query_states = query_states.transpose(1, 2)
 | 
				
			||||||
 | 
					    key_states = key_states.transpose(1, 2)
 | 
				
			||||||
 | 
					    value_states = value_states.view(bsz, q_len,
 | 
				
			||||||
 | 
					                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    past_key_value = getattr(self, "past_key_value", past_key_value)
 | 
				
			||||||
 | 
					    kv_seq_len = key_states.shape[-2]
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
				
			||||||
 | 
					    cos, sin = self.rotary_emb(value_states, position_ids)
 | 
				
			||||||
 | 
					    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        # sin and cos are specific to RoPE models; position_ids needed for the static cache
 | 
				
			||||||
 | 
					        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 | 
				
			||||||
 | 
					        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx,
 | 
				
			||||||
 | 
					                                                         cache_kwargs, new_layout=True)
 | 
				
			||||||
 | 
					    if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
 | 
				
			||||||
 | 
					            and not hidden_states.requires_grad:
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
 | 
					        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
 | 
				
			||||||
 | 
					                                          attention_mask)
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        key_states, value_states = restore_fp8_kv_cache(key_states,
 | 
				
			||||||
 | 
					                                                        value_states, query_states.dtype)
 | 
				
			||||||
 | 
					        key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					        value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_weights = torch.matmul(query_states,
 | 
				
			||||||
 | 
					                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if attention_mask is not None:  # no matter the length, we just slice it
 | 
				
			||||||
 | 
					            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
 | 
				
			||||||
 | 
					            attn_weights = attn_weights + causal_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # upcast attention to fp32
 | 
				
			||||||
 | 
					        attn_weights = nn.functional.softmax(attn_weights,
 | 
				
			||||||
 | 
					                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
				
			||||||
 | 
					        attn_weights = nn.functional.dropout(attn_weights,
 | 
				
			||||||
 | 
					                                             p=self.attention_dropout, training=self.training)
 | 
				
			||||||
 | 
					        attn_output = torch.matmul(attn_weights, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
 | 
				
			||||||
 | 
					                      "`attn_output` should be of size "
 | 
				
			||||||
 | 
					                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
 | 
				
			||||||
 | 
					                      f" but is {attn_output.size()}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = self.o_proj(attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not output_attentions:
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return attn_output, attn_weights, past_key_value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cohere_attention_forward_origin(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
				
			||||||
 | 
					    output_attentions: bool = False,
 | 
				
			||||||
 | 
					    use_cache: bool = False,
 | 
				
			||||||
 | 
					    cache_position: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
				
			||||||
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
				
			||||||
 | 
					    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
				
			||||||
 | 
					    decoding_fast_path = use_decoding_fast_path(self.q_proj,
 | 
				
			||||||
 | 
					                                                use_fuse_rope,
 | 
				
			||||||
 | 
					                                                enough_kv_room,
 | 
				
			||||||
 | 
					                                                bsz * q_len)
 | 
				
			||||||
 | 
					    if decoding_fast_path:
 | 
				
			||||||
 | 
					        hidden_states = hidden_states.view(1, -1)
 | 
				
			||||||
 | 
					        cache_k = past_key_value.key_cache[self.layer_idx]
 | 
				
			||||||
 | 
					        cache_v = past_key_value.value_cache[self.layer_idx]
 | 
				
			||||||
 | 
					        kv_seq_len = cache_k.shape[-2]
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
 | 
					        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
 | 
				
			||||||
 | 
					                                                                         self.q_proj.weight,
 | 
				
			||||||
 | 
					                                                                         self.k_proj.weight,
 | 
				
			||||||
 | 
					                                                                         self.v_proj.weight,
 | 
				
			||||||
 | 
					                                                                         position_ids,
 | 
				
			||||||
 | 
					                                                                         cache_k, cache_v,
 | 
				
			||||||
 | 
					                                                                         self.q_proj.weight.qtype,
 | 
				
			||||||
 | 
					                                                                         self.v_proj.weight.qtype,
 | 
				
			||||||
 | 
					                                                                         kv_seq_len,
 | 
				
			||||||
 | 
					                                                                         self.head_dim,
 | 
				
			||||||
 | 
					                                                                         self.rotary_emb.base,)
 | 
				
			||||||
 | 
					        kv_seq_len += 1
 | 
				
			||||||
 | 
					        # update past_key_value's seem_tokens and kv caches.
 | 
				
			||||||
 | 
					        if self.layer_idx == 0:
 | 
				
			||||||
 | 
					            past_key_value._seen_tokens = kv_seq_len
 | 
				
			||||||
 | 
					        past_key_value.key_cache[self.layer_idx] = key_states
 | 
				
			||||||
 | 
					        past_key_value.value_cache[self.layer_idx] = value_states
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        query_states = self.q_proj(hidden_states)
 | 
				
			||||||
 | 
					        key_states = self.k_proj(hidden_states)
 | 
				
			||||||
 | 
					        value_states = self.v_proj(hidden_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
 | 
				
			||||||
 | 
					        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
 | 
				
			||||||
 | 
					        if self.use_qk_norm:
 | 
				
			||||||
 | 
					            query_states = self.q_norm(query_states)
 | 
				
			||||||
 | 
					            key_states = self.k_norm(key_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        query_states = query_states.transpose(1, 2)
 | 
				
			||||||
 | 
					        key_states = key_states.transpose(1, 2)
 | 
				
			||||||
 | 
					        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
 | 
				
			||||||
 | 
					                                         self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        past_key_value = getattr(self, "past_key_value", past_key_value)
 | 
				
			||||||
 | 
					        kv_seq_len = key_states.shape[-2]
 | 
				
			||||||
 | 
					        if past_key_value is not None:
 | 
				
			||||||
 | 
					            if self.layer_idx is None:
 | 
				
			||||||
 | 
					                invalidInputError(
 | 
				
			||||||
 | 
					                    False,
 | 
				
			||||||
 | 
					                    "The cache structure has changed since version v4.36. "
 | 
				
			||||||
 | 
					                    f"If you are using {self.__class__.__name__} "
 | 
				
			||||||
 | 
					                    "for auto-regressive decoding with k/v caching, "
 | 
				
			||||||
 | 
					                    "please make sure to initialize the attention class with a layer index."
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 | 
				
			||||||
 | 
					        cos, sin = self.rotary_emb(value_states, position_ids)
 | 
				
			||||||
 | 
					        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if past_key_value is not None:
 | 
				
			||||||
 | 
					            if self.layer_idx == 0:
 | 
				
			||||||
 | 
					                past_key_value._seen_tokens += key_states.shape[-2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if len(past_key_value.key_cache) <= self.layer_idx:
 | 
				
			||||||
 | 
					                past_key_value.key_cache.append(key_states)
 | 
				
			||||||
 | 
					                past_key_value.value_cache.append(value_states)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                cache_k = past_key_value.key_cache[self.layer_idx]
 | 
				
			||||||
 | 
					                cache_v = past_key_value.value_cache[self.layer_idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if not enough_kv_room:
 | 
				
			||||||
 | 
					                    # allocate new
 | 
				
			||||||
 | 
					                    new_c_k, new_c_v = extend_kv_cache(bsz,
 | 
				
			||||||
 | 
					                                                       self.num_key_value_heads,  # Support GQA
 | 
				
			||||||
 | 
					                                                       self.head_dim,
 | 
				
			||||||
 | 
					                                                       cache_k.size(2),
 | 
				
			||||||
 | 
					                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
				
			||||||
 | 
					                                                       dtype=cache_k.dtype,
 | 
				
			||||||
 | 
					                                                       device=device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    new_c_k[:] = cache_k
 | 
				
			||||||
 | 
					                    new_c_v[:] = cache_v
 | 
				
			||||||
 | 
					                    cache_k = new_c_k
 | 
				
			||||||
 | 
					                    cache_v = new_c_v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                key_states, value_states = append_kv_cache(cache_k,
 | 
				
			||||||
 | 
					                                                           cache_v,
 | 
				
			||||||
 | 
					                                                           key_states,
 | 
				
			||||||
 | 
					                                                           value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # update past_key_value
 | 
				
			||||||
 | 
					                past_key_value.key_cache[self.layer_idx] = key_states
 | 
				
			||||||
 | 
					                past_key_value.value_cache[self.layer_idx] = value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    key_states = repeat_kv(key_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					    value_states = repeat_kv(value_states, self.num_key_value_groups)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not self.training and not hidden_states.requires_grad and \
 | 
				
			||||||
 | 
					            use_flash_attention(query_states, key_states, attention_mask):
 | 
				
			||||||
 | 
					        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
 | 
				
			||||||
 | 
					                                                     key_states.to(device, dtype=torch.float16),
 | 
				
			||||||
 | 
					                                                     value_states.to(device, dtype=torch.float16),
 | 
				
			||||||
 | 
					                                                     is_causal=True)
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					    elif not self.training and not hidden_states.requires_grad and \
 | 
				
			||||||
 | 
					            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
 | 
					        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
 | 
				
			||||||
 | 
					        attn_output = attn_output.view(query_states.shape)
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        attn_weights = torch.matmul(query_states,
 | 
				
			||||||
 | 
					                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if attention_mask is not None:  # no matter the length, we just slice it
 | 
				
			||||||
 | 
					            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
 | 
				
			||||||
 | 
					            attn_weights = attn_weights + causal_mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # upcast attention to fp32
 | 
				
			||||||
 | 
					        attn_weights = nn.functional.softmax(attn_weights,
 | 
				
			||||||
 | 
					                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
 | 
				
			||||||
 | 
					        attn_weights = nn.functional.dropout(attn_weights,
 | 
				
			||||||
 | 
					                                             p=self.attention_dropout, training=self.training)
 | 
				
			||||||
 | 
					        attn_output = torch.matmul(attn_weights, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
 | 
				
			||||||
 | 
					                      "`attn_output` should be of size "
 | 
				
			||||||
 | 
					                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
 | 
				
			||||||
 | 
					                      f" but is {attn_output.size()}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.transpose(1, 2).contiguous()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = self.o_proj(attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not output_attentions:
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 | 
				
			||||||
		Loading…
	
		Reference in a new issue