stablelm fp8 kv cache (#10672)

* stablelm fp8 kvcache

* update

* fix

* change to fp8 matmul

* fix style

* fix

* fix

* meet code review

* add comment
This commit is contained in:
Xin Qiu 2024-04-08 15:16:46 +08:00 committed by GitHub
parent 65127622aa
commit 1274cba79b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 244 additions and 30 deletions

View file

@ -633,7 +633,7 @@ def _optimize_pre(model):
del module.c_attn
model.apply(split_qkv_proj_func)
if model.config.model_type == "stablelm":
# For stablelm-zephyr-3b
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
from ipex_llm.transformers.models.stablelm import merge_qkv
model.apply(merge_qkv)
@ -1342,10 +1342,11 @@ def _optimize_post(model, lightweight_bmm=False):
module.BertEncoder,
encoder_forward)
elif model.config.model_type == 'stablelm':
# For stablelm-zephyr-3b
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
from ipex_llm.transformers.models.stablelm import stablelm_model_forward
convert_forward(model,
module.StableLmAttention,
stablelm_attention_forward
@ -1353,5 +1354,8 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.StableLmMLP,
llama_mlp_forward)
convert_forward(model,
module.StableLmModel,
stablelm_model_forward
)
return model

View file

@ -38,17 +38,20 @@
#
import math
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Union
import torch
from torch import nn
import torch.nn.functional as F
from transformers.models.stablelm.modeling_stablelm import StableLmAttention
from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
apply_rotary_pos_emb_cache_freq_xpu
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
restore_fp8_kv_cache, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
@ -87,7 +90,68 @@ def merge_qkv(module: torch.nn.Module):
del module.q_proj, module.k_proj, module.v_proj
def stablelm_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache and use_quantize_kv_cache_stablelm(self.layers[0].self_attn.head_dim,
self.layers[0].mlp.up_proj,
input_ids):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return StableLmModel.forward(
self=self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def use_quantize_kv_cache_stablelm(head_dim: int, linear: torch.nn.Module, x: torch.Tensor) -> bool:
return (head_dim == 64 or head_dim == 128) and use_quantize_kv_cache(linear, x)
def stablelm_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_quantize_kv_cache_stablelm(self.head_dim, self.o_proj, hidden_states):
forward_function = stablelm_attention_forward_quantized
else:
forward_function = stablelm_attention_forward_original
return forward_function(
self=self,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
def stablelm_attention_forward_original(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor]=None,
@ -116,8 +180,7 @@ def stablelm_attention_forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(False,
invalidInputError(self.layer_idx is not None,
"The cache structure has changed since version v4.36. "
f"If you are using {self.__class__.__name__} for "
"auto-regressive decodingwith k/v caching, please make sure "
@ -134,6 +197,7 @@ def stablelm_attention_forward(
key_states[..., self.rotary_emb.dim:],
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
if use_fuse_rope:
query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
key_rot,
@ -142,7 +206,6 @@ def stablelm_attention_forward(
"stablelm",
position_ids)
else:
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot,
key_rot,
cos,
@ -214,20 +277,16 @@ def stablelm_attention_forward(
query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
f" but is {attn_weights.size()}"
)
f" but is {attn_weights.size()}")
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
f" but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
@ -238,12 +297,10 @@ def stablelm_attention_forward(
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
False,
attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
f" but is {attn_output.size()}"
)
f" but is {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@ -253,3 +310,156 @@ def stablelm_attention_forward(
attn_weights = None
return attn_output.to(original_dtype), attn_weights, past_key_value
def stablelm_attention_forward_quantized(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor]=None,
position_ids: Optional[torch.LongTensor]=None,
past_key_value: Optional[Cache]=None,
output_attentions: bool=False,
use_cache: bool=False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
bsz, q_len, hidden_size = hidden_states.size()
device = hidden_states.device
# for flash attention
original_dtype = hidden_states.dtype
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
qkv = qkv.transpose(1, 2)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_heads,
self.num_heads], dim=1)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
invalidInputError(
self.layer_idx is not None,
f"The cache structure has changed since version v4.36. "
"If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, "
"please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Partial rotary embedding
query_rot, query_pass = (
query_states[..., : self.rotary_emb.dim],
query_states[..., self.rotary_emb.dim:],
)
key_rot, key_pass = (
key_states[..., : self.rotary_emb.dim],
key_states[..., self.rotary_emb.dim:],
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
if use_fuse_rope:
query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
key_rot,
sin,
cos,
"stablelm",
position_ids)
else:
query_rot, key_rot = apply_rotary_pos_emb(query_rot,
key_rot,
cos,
sin,
position_ids,
"stablelm")
# [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
kv_seq_len = key_states.shape[-2]
if len(past_key_value.key_cache) <= self.layer_idx:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
attn_weights = attn_weights / math.sqrt(self.head_dim)
invalidInputError(
attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}"
f", but is {attn_weights.size()}")
if attention_mask is not None:
invalidInputError(
attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
# at inference time, for memory considerations, may not need to upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
attn_weights = self.attention_dropout(attn_weights)
attn_output = torch.matmul(attn_weights, value_states)
invalidInputError(
attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}"
f", but is {attn_output.size()}")
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
kv_seq_len = key_states.shape[-2]
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
else:
import linear_q4_0
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
attn_weights = attn_weights / math.sqrt(self.head_dim)
invalidInputError(
attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}"
f", but is {attn_weights.size()}")
if attention_mask is not None:
invalidInputError(
attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
# at inference time, for memory considerations, may not need to upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = self.attention_dropout(attn_weights)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
attn_output = torch.matmul(attn_weights, value_states)
else:
import linear_q4_0
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
value_states.transpose(-1, -2))
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
invalidInputError(attn_output.size() == attn_output_size,
f"`attn_output` should be of size {attn_output_size},"
f" but is {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output.to(original_dtype), attn_weights, past_key_value