stablelm fp8 kv cache (#10672)

* stablelm fp8 kvcache

* update

* fix

* change to fp8 matmul

* fix style

* fix

* fix

* meet code review

* add comment
Xin Qiu 2024-04-08 15:16:46 +08:00 committed by GitHub
parent 65127622aa
commit 1274cba79b
2 changed files with 244 additions and 30 deletions
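In short, this commit lets StableLM models run with an FP8-quantized KV cache: stablelm_model_forward swaps the legacy cache for a DynamicFp8Cache when the runtime check passes, and stablelm_attention_forward dispatches between the original and the quantized attention implementation. Below is a minimal, self-contained sketch of that dispatch pattern; the quantize_kv_enabled flag is a stand-in for ipex_llm's use_quantize_kv_cache check, which is not reproduced here.

from typing import Callable

def use_quantize_kv_cache_stablelm(head_dim: int, quantize_kv_enabled: bool) -> bool:
    # The FP8 KV cache path is only taken for head sizes the fused kernels support.
    return head_dim in (64, 128) and quantize_kv_enabled

def pick_attention_forward(head_dim: int, quantize_kv_enabled: bool,
                           quantized_fn: Callable, original_fn: Callable) -> Callable:
    # Mirrors the stablelm_attention_forward dispatcher in the diff: choose the
    # FP8 path when allowed, otherwise fall back to the unquantized implementation.
    if use_quantize_kv_cache_stablelm(head_dim, quantize_kv_enabled):
        return quantized_fn
    return original_fn

if __name__ == "__main__":
    forward = pick_attention_forward(64, True,
                                     quantized_fn=lambda: "fp8 kv cache path",
                                     original_fn=lambda: "original path")
    print(forward())  # prints "fp8 kv cache path"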

ipex_llm/transformers/convert.py

@@ -633,7 +633,7 @@ def _optimize_pre(model):
                 del module.c_attn
         model.apply(split_qkv_proj_func)
     if model.config.model_type == "stablelm":
-        # For stablelm-zephyr-3b
+        # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
         model.apply(merge_qkv)
@@ -1342,10 +1342,11 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.BertEncoder,
                         encoder_forward)
     elif model.config.model_type == 'stablelm':
-        # For stablelm-zephyr-3b
+        # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
+        from ipex_llm.transformers.models.stablelm import stablelm_model_forward
         convert_forward(model,
                         module.StableLmAttention,
                         stablelm_attention_forward
@@ -1353,5 +1354,8 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.StableLmMLP,
                         llama_mlp_forward)
+        convert_forward(model,
+                        module.StableLmModel,
+                        stablelm_model_forward
+                        )
     return model
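These conversions are applied automatically when a StableLM checkpoint is loaded through ipex-llm's optimized AutoModel classes. A hedged usage sketch follows; the arguments mirror common ipex-llm examples and may vary across versions, and stabilityai/stablelm-zephyr-3b is the model family named in the comments above.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Loading with low-bit weights triggers _optimize_pre/_optimize_post, which patch
# in the StableLM forwards shown in this diff; the FP8 matmul kernels target XPU.
model_path = "stabilityai/stablelm-zephyr-3b"
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True)
model = model.to("xpu")

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))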

ipex_llm/transformers/models/stablelm.py

@@ -38,17 +38,20 @@
 #
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Union
 import torch
 from torch import nn
 import torch.nn.functional as F
-from transformers.models.stablelm.modeling_stablelm import StableLmAttention
+from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     apply_rotary_pos_emb_cache_freq_xpu
+from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
+    restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
@@ -87,7 +90,68 @@ def merge_qkv(module: torch.nn.Module):
     del module.q_proj, module.k_proj, module.v_proj


+def stablelm_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    from ipex_llm.transformers.kv import DynamicFp8Cache
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    if use_cache and use_quantize_kv_cache_stablelm(self.layers[0].self_attn.head_dim,
+                                                    self.layers[0].mlp.up_proj,
+                                                    input_ids):
+        if not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    return StableLmModel.forward(
+        self=self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+
+def use_quantize_kv_cache_stablelm(head_dim: int, linear: torch.nn.Module, x: torch.Tensor) -> bool:
+    return (head_dim == 64 or head_dim == 128) and use_quantize_kv_cache(linear, x)
+
+
 def stablelm_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if use_quantize_kv_cache_stablelm(self.head_dim, self.o_proj, hidden_states):
+        forward_function = stablelm_attention_forward_quantized
+    else:
+        forward_function = stablelm_attention_forward_original
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+    )
+
+
+def stablelm_attention_forward_original(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor]=None,
@@ -116,8 +180,7 @@ def stablelm_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
-        if self.layer_idx is None:
-            invalidInputError(False,
+        invalidInputError(self.layer_idx is not None,
                           "The cache structure has changed since version v4.36. "
                           f"If you are using {self.__class__.__name__} for "
                           "auto-regressive decoding with k/v caching, please make sure "
@@ -134,6 +197,7 @@ def stablelm_attention_forward(
         key_states[..., self.rotary_emb.dim:],
     )
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
     if use_fuse_rope:
         query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
                                                                  key_rot,
@@ -142,7 +206,6 @@ def stablelm_attention_forward(
                                                                  "stablelm",
                                                                  position_ids)
     else:
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
         query_rot, key_rot = apply_rotary_pos_emb(query_rot,
                                                   key_rot,
                                                   cos,
@@ -214,20 +277,16 @@ def stablelm_attention_forward(
                                     query_states,
                                     key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-        invalidInputError(
-            False,
-            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
-            f" but is {attn_weights.size()}"
-        )
+    invalidInputError(
+        attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
+        f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
+        f" but is {attn_weights.size()}")

     if attention_mask is not None:
-        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                f" but is {attention_mask.size()}"
-            )
+        invalidInputError(
+            attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
+            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+            f" but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
@@ -238,12 +297,10 @@ def stablelm_attention_forward(
     attn_output = torch.matmul(attn_weights, value_states)

-    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-        invalidInputError(
-            False,
-            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
-            f" but is {attn_output.size()}"
-        )
+    invalidInputError(
+        attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
+        f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
+        f" but is {attn_output.size()}")

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -253,3 +310,156 @@ def stablelm_attention_forward(
         attn_weights = None

     return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def stablelm_attention_forward_quantized(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor]=None,
+    position_ids: Optional[torch.LongTensor]=None,
+    past_key_value: Optional[Cache]=None,
+    output_attentions: bool=False,
+    use_cache: bool=False,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+    bsz, q_len, hidden_size = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+
+    qkv = self.qkv_proj(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_key_value_heads,
+                                                        self.num_key_value_heads], dim=1)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        invalidInputError(
+            self.layer_idx is not None,
+            "The cache structure has changed since version v4.36. "
+            f"If you are using {self.__class__.__name__} "
+            "for auto-regressive decoding with k/v caching, "
+            "please make sure to initialize the attention class "
+            "with a layer index."
+        )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., : self.rotary_emb.dim],
+        query_states[..., self.rotary_emb.dim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., : self.rotary_emb.dim],
+        key_states[..., self.rotary_emb.dim:],
+    )
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    if use_fuse_rope:
+        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
+                                                                 key_rot,
+                                                                 sin,
+                                                                 cos,
+                                                                 "stablelm",
+                                                                 position_ids)
+    else:
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot,
                                                  key_rot,
                                                  cos,
                                                  sin,
                                                  position_ids,
                                                  "stablelm")
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+    kv_seq_len = key_states.shape[-2]
+    if len(past_key_value.key_cache) <= self.layer_idx:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        attn_weights = attn_weights / math.sqrt(self.head_dim)
+
+        invalidInputError(
+            attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
+            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}"
+            f", but is {attn_weights.size()}")
+
+        if attention_mask is not None:
+            invalidInputError(
+                attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                f" but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
+
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
+        attn_weights = self.attention_dropout(attn_weights)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        invalidInputError(
+            attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}"
+            f", but is {attn_output.size()}")
+
+        if use_cache:
+            cache_kwargs = None
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+    else:
+        cache_kwargs = None  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, cache_kwargs)
+        kv_seq_len = key_states.shape[-2]
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+            # repeat k/v heads if n_kv_heads < n_heads
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        else:
+            import linear_q4_0
+            attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
+        attn_weights = attn_weights / math.sqrt(self.head_dim)
+
+        invalidInputError(
+            attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
+            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}"
+            f", but is {attn_weights.size()}")
+
+        if attention_mask is not None:
+            invalidInputError(
+                attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                f" but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
+
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = self.attention_dropout(attn_weights)
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            attn_output = torch.matmul(attn_weights, value_states)
+        else:
+            import linear_q4_0
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
+                                                            value_states.transpose(-1, -2))
+
+        attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
+        invalidInputError(attn_output.size() == attn_output_size,
+                          f"`attn_output` should be of size {attn_output_size},"
+                          f" but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
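For intuition about what the quantized cache buys: keys and values are stored at one byte per element instead of two, then either dequantized back to the compute dtype (restore_fp8_kv_cache, prefill and non-XPU paths) or consumed directly by the FP8 matmul kernels (single-token decode on XPU). The toy round trip below uses PyTorch's float8_e5m2 purely as a stand-in for ipex_llm's XPU-specific cache format and kernels; it only illustrates the storage/accuracy trade-off (requires PyTorch 2.1+).

import torch

# Shapes roughly matching a StableLM layer with head_dim = 64.
bsz, n_kv_heads, seq_len, head_dim = 1, 32, 1024, 64
k_fp16 = torch.randn(bsz, n_kv_heads, seq_len, head_dim, dtype=torch.float16)

k_fp8 = k_fp16.to(torch.float8_e5m2)   # "append": cache held at 1 byte per element
k_restored = k_fp8.to(torch.float16)   # "restore": back to fp16 before a regular matmul

print(k_fp16.element_size(), "->", k_fp8.element_size())   # 2 -> 1 bytes per value
print((k_fp16 - k_restored).abs().max())                    # quantization error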