From d90cd977d04481a5bacdf72aac642fefb4b829af Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 4 Jun 2024 13:14:43 +0800 Subject: [PATCH] refactor stablelm (#11195) --- .../ipex_llm/transformers/models/stablelm.py | 368 +++--------------- 1 file changed, 61 insertions(+), 307 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py index c8a84557..441b49cf 100644 --- a/python/llm/src/ipex_llm/transformers/models/stablelm.py +++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py @@ -38,31 +38,19 @@ # import math -from typing import Optional, Tuple, List, Union +from typing import Optional, Tuple, List import torch -from torch import nn -import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.models.stablelm.modeling_stablelm import repeat_kv from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel -from transformers.modeling_outputs import BaseModelOutputWithPast -from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \ apply_rotary_pos_emb_cache_freq_xpu -from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ - restore_fp8_kv_cache, use_quantize_kv_cache -from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36 -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp -from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv -try: - from transformers.cache_utils import Cache -except ImportError: - Cache = Tuple[torch.Tensor] - -import os - -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) +from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal +from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache +from ipex_llm.transformers.models.utils import should_use_fuse_rope +from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache def merge_qkv(module: torch.nn.Module): @@ -92,24 +80,26 @@ def merge_qkv(module: torch.nn.Module): def stablelm_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - from ipex_llm.transformers.kv import DynamicFp8Cache + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + # IPEX-LLM OPT: kv cache and quantize kv cache use_cache = use_cache if use_cache is not None else self.config.use_cache - if use_cache and use_quantize_kv_cache_stablelm(self.layers[0].self_attn.head_dim, - self.layers[0].mlp.up_proj, - input_ids): - if not isinstance(past_key_values, 
DynamicFp8Cache): + use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 96, 128] + and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids)) + if use_cache: + if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) + if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) return StableLmModel.forward( self=self, input_ids=input_ids, @@ -124,10 +114,6 @@ def stablelm_model_forward( ) -def use_quantize_kv_cache_stablelm(head_dim: int, linear: torch.nn.Module, x: torch.Tensor) -> bool: - return (head_dim == 64 or head_dim == 128) and use_quantize_kv_cache(linear, x) - - def stablelm_attention_forward( self, hidden_states: torch.Tensor, @@ -137,55 +123,17 @@ def stablelm_attention_forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if use_quantize_kv_cache_stablelm(self.head_dim, self.o_proj, hidden_states): - forward_function = stablelm_attention_forward_quantized - else: - forward_function = stablelm_attention_forward_original - return forward_function( - self=self, - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - -def stablelm_attention_forward_original( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor]=None, - position_ids: Optional[torch.LongTensor]=None, - past_key_value: Optional[Cache]=None, - output_attentions: bool=False, - use_cache: bool=False, - **kwargs -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - # for flash attention - original_dtype = hidden_states.dtype - - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) qkv = self.qkv_proj(hidden_states) qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) qkv = qkv.transpose(1, 2) query_states, key_states, value_states = qkv.split([self.num_heads, - self.num_heads, - self.num_heads], dim=1) + self.num_key_value_heads, + self.num_key_value_heads], dim=1) kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - invalidInputError(self.layer_idx is not None, - "The cache structure has changed since version v4.36. 
" - f"If you are using {self.__class__.__name__} for " - "auto-regressive decodingwith k/v caching, please make sure " - "to initialize the attention class with a layer index.") kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Partial rotary embedding @@ -198,8 +146,8 @@ def stablelm_attention_forward_original( key_states[..., self.rotary_emb.dim:], ) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - if use_fuse_rope: + # [batch_size, num_heads, seq_length, head_dim // config.partial_rotary_factor] + if should_use_fuse_rope(hidden_states, position_ids, self.training): query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot, key_rot, sin, @@ -214,94 +162,52 @@ def stablelm_attention_forward_original( position_ids, "stablelm") - # [batch_size, seq_length, num_heads, head_dim] + # [batch_size, num_heads, seq_length, head_dim] query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: - # update the number of seen tokens - if self.layer_idx == 0: - past_key_value.seen_tokens += key_states.shape[-2] + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, None) - # reuse k, v, self_attention - # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx` - if len(past_key_value.key_cache) <= self.layer_idx: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - cache_k = past_key_value.key_cache[self.layer_idx] - cache_v = past_key_value.value_cache[self.layer_idx] - - if not enough_kv_room: - # allocate new - new_c_k, new_c_v = extend_kv_cache(bsz, - self.num_key_value_heads, # Support GQA - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - - new_c_k[:] = cache_k - new_c_v[:] = cache_v - cache_k = new_c_k - cache_v = new_c_v - - key_states, value_states = append_kv_cache(cache_k, cache_v, - key_states, value_states) - - # update past_key_value - past_key_value.key_cache[self.layer_idx] = key_states - past_key_value.value_cache[self.layer_idx] = value_states - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query_states, key_states, attention_mask): - attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16), - key_states.to(device, dtype=torch.float16), - value_states.to(device, dtype=torch.float16), - is_causal=True) - attn_weights = None - elif not self.training and not hidden_states.requires_grad and \ - use_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + # IPEX-LLM OPT: sdp + attn_weights = None + if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): import xe_addons - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - attn_output = attn_output.view(query_states.shape) - attn_weights = None + if isinstance(past_key_value, DynamicFp8Cache): + attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, + attention_mask) + else: + attn_output = xe_addons.sdp(query_states, key_states, value_states, + attention_mask) + elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, 
self.training): + import xe_addons + if isinstance(past_key_value, DynamicFp8Cache): + attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, + value_states, attention_mask) + else: + attn_output = xe_addons.sdp_causal(query_states, key_states, + value_states, attention_mask) else: - attn_weights = torch.matmul( - query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if isinstance(past_key_value, DynamicFp8Cache): + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) - invalidInputError( - attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}," - f" but is {attn_weights.size()}") + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: - invalidInputError( - attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}") - attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = \ - nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(value_states.dtype) attn_weights = self.attention_dropout(attn_weights) - attn_output = torch.matmul(attn_weights, value_states) - invalidInputError( - attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}") - attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -309,156 +215,4 @@ def stablelm_attention_forward_original( if not output_attentions: attn_weights = None - return attn_output.to(original_dtype), attn_weights, past_key_value - - -def stablelm_attention_forward_quantized( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor]=None, - position_ids: Optional[torch.LongTensor]=None, - past_key_value: Optional[Cache]=None, - output_attentions: bool=False, - use_cache: bool=False, - **kwargs -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: - bsz, q_len, hidden_size = hidden_states.size() - device = hidden_states.device - # for flash attention - original_dtype = hidden_states.dtype - - use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) - qkv = self.qkv_proj(hidden_states) - qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) - qkv = qkv.transpose(1, 2) - query_states, key_states, value_states = qkv.split([self.num_heads, - self.num_heads, - self.num_heads], dim=1) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - invalidInputError( - self.layer_idx is not None, - f"The cache structure has changed since version v4.36. " - "If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, " - "please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_emb.dim], - query_states[..., self.rotary_emb.dim:], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_emb.dim], - key_states[..., self.rotary_emb.dim:], - ) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - if use_fuse_rope: - query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot, - key_rot, - sin, - cos, - "stablelm", - position_ids) - else: - query_rot, key_rot = apply_rotary_pos_emb(query_rot, - key_rot, - cos, - sin, - position_ids, - "stablelm") - - # [batch_size, seq_length, num_heads, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) - - kv_seq_len = key_states.shape[-2] - if len(past_key_value.key_cache) <= self.layer_idx: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - attn_weights = attn_weights / math.sqrt(self.head_dim) - - invalidInputError( - attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}" - f", but is {attn_weights.size()}") - - if attention_mask is not None: - invalidInputError( - attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}") - attn_weights = attn_weights + attention_mask - - # at inference time, for memory considerations, may not need to upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype) - attn_weights = self.attention_dropout(attn_weights) - - attn_output = torch.matmul(attn_weights, value_states) - - invalidInputError( - attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}" - f", but is {attn_output.size()}") - if use_cache: - cache_kwargs = None - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) - else: - cache_kwargs = None # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) - kv_seq_len = key_states.shape[-2] - if query_states.size(2) != 1 or query_states.device.type != 'xpu': - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - else: - import xe_addons - attn_weights = xe_addons.query_key_fp8_matmul(query_states, key_states) - - attn_weights = attn_weights / math.sqrt(self.head_dim) - - invalidInputError( - attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}" - f", but is {attn_weights.size()}") - - if attention_mask is not None: - invalidInputError( - attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}") - attn_weights = attn_weights + attention_mask - - # at inference time, for memory 
considerations, may not need to upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - attn_weights = self.attention_dropout(attn_weights) - - if query_states.size(2) != 1 or query_states.device.type != 'xpu': - attn_output = torch.matmul(attn_weights, value_states) - else: - import xe_addons - attn_output = xe_addons.attn_value_fp8_matmul(attn_weights, value_states) - - attn_output_size = (bsz, self.num_heads, q_len, self.head_dim) - invalidInputError(attn_output.size() == attn_output_size, - f"`attn_output` should be of size {attn_output_size}," - f" but is {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output.to(original_dtype), attn_weights, past_key_value + return attn_output, attn_weights, past_key_value
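
Not part of the patch: a minimal, self-contained sketch of the fused-QKV shape bookkeeping that the refactored stablelm_attention_forward relies on after merge_qkv concatenates q_proj/k_proj/v_proj into a single linear layer. All dimensions below (32 query heads, 8 KV heads, head_dim 64) are hypothetical and chosen only to show why the split sizes must be [num_heads, num_key_value_heads, num_key_value_heads] rather than three equal num_heads chunks, which is the qkv.split fix visible in the diff above.

    # Minimal sketch (not from the patch): fused QKV projection and split,
    # with hypothetical dimensions.
    import torch

    bsz, q_len = 2, 16
    num_heads, num_key_value_heads, head_dim = 32, 8, 64   # hypothetical GQA config
    hidden_size = num_heads * head_dim

    # One matmul produces (num_heads + 2 * num_key_value_heads) head slices.
    qkv_proj = torch.nn.Linear(hidden_size,
                               (num_heads + 2 * num_key_value_heads) * head_dim,
                               bias=False)

    hidden_states = torch.randn(bsz, q_len, hidden_size)
    qkv = qkv_proj(hidden_states)
    qkv = qkv.view(bsz, q_len, num_heads + 2 * num_key_value_heads, head_dim)
    qkv = qkv.transpose(1, 2)

    # Split sizes must be [num_heads, num_key_value_heads, num_key_value_heads];
    # three equal num_heads chunks only works when the model has no GQA.
    query_states, key_states, value_states = qkv.split(
        [num_heads, num_key_value_heads, num_key_value_heads], dim=1)

    print(query_states.shape)  # torch.Size([2, 32, 16, 64])
    print(key_states.shape)    # torch.Size([2, 8, 16, 64])
    print(value_states.shape)  # torch.Size([2, 8, 16, 64])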
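
Also not part of the patch: a rough sketch of the KV-cache selection pattern that stablelm_model_forward now follows. DynamicNormalCacheStub and DynamicFp8CacheStub are stand-ins invented here purely for illustration; the real DynamicNormalCache / DynamicFp8Cache classes live in ipex_llm.transformers.kv and are converted from legacy tuple caches via from_legacy_cache, as the diff shows. The head_dim guard restricts the FP8 path to 64/96/128, presumably because only those head sizes are covered by the fused FP8 kernels; everything else falls back to the normal dynamic cache.

    # Minimal sketch (not from the patch): cache-selection logic with stand-in
    # classes replacing the IPEX-LLM cache types.
    class DynamicNormalCacheStub:            # stand-in for DynamicNormalCache
        @classmethod
        def from_legacy_cache(cls, past):
            return cls()

    class DynamicFp8CacheStub:               # stand-in for DynamicFp8Cache
        @classmethod
        def from_legacy_cache(cls, past):
            return cls()

    def select_kv_cache(past_key_values, head_dim, quantize_ok, use_cache=True):
        # Mirror of the patch's logic: FP8 KV cache only for supported head_dims
        # and when quantized KV cache is enabled, otherwise the normal cache.
        use_quantize_kv = head_dim in [64, 96, 128] and quantize_ok
        if not use_cache:
            return past_key_values
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8CacheStub):
            past_key_values = DynamicFp8CacheStub.from_legacy_cache(past_key_values)
        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCacheStub):
            past_key_values = DynamicNormalCacheStub.from_legacy_cache(past_key_values)
        return past_key_values

    print(type(select_kv_cache(None, head_dim=64, quantize_ok=True)).__name__)  # DynamicFp8CacheStub
    print(type(select_kv_cache(None, head_dim=80, quantize_ok=True)).__name__)  # DynamicNormalCacheStub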