From 31ce3e0c1344c4856180df607042ddf73658ef3a Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 17 May 2024 16:25:30 +0800 Subject: [PATCH] refactor baichuan2-13b (#11064) --- .../ipex_llm/transformers/models/baichuan2.py | 275 +++--------------- .../src/ipex_llm/transformers/models/utils.py | 47 +++ 2 files changed, 88 insertions(+), 234 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py index d5a93372..f72adedb 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py @@ -23,19 +23,13 @@ from typing import Optional, Tuple import torch import torch.utils.checkpoint from torch.nn import functional as F -from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ - restore_fp8_kv_cache, use_quantize_kv_cache -from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \ - append_kv_cache, is_enough_kv_cache_room_4_31 +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import update_past_key_value from ipex_llm.transformers.models.utils import should_use_fuse_rope from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU from ipex_llm.transformers.models.utils import mlp_fusion_check import warnings -import os - - -KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) def pre_compute_inv_freq(module: torch.nn.Module): @@ -114,52 +108,16 @@ def baichuan_attention_forward_7b( # IPEX-LLM OPT: kv cache and quantize kv use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) - if use_quantize_kv: - if past_key_value is None: - k_cache, v_cache = init_fp8_kv_cache( - bsz, self.num_heads, kv_seq_len, self.head_dim, - device=device - ) - else: - k_cache, v_cache = past_key_value - key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, - key_states, value_states) - else: - if past_key_value is None: - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - k_cache, v_cache = init_kv_cache(bsz, - self.num_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key_states.dtype, - device=device) - k_cache[...] = key_states - v_cache[...] = value_states - key_states = k_cache - value_states = v_cache - else: - k_cache, v_cache = past_key_value - if k_cache.stride(1) < kv_seq_len * k_cache.size(3): - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_k_cache, new_v_cache = extend_kv_cache(bsz, - self.num_heads, - self.head_dim, - k_cache.size(2), - max_cache_length, - dtype=k_cache.dtype, - device=device) - new_k_cache[...] = k_cache - new_v_cache[...] = v_cache - k_cache = new_k_cache - v_cache = new_v_cache - key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states) - + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device + ) past_key_value = (key_states, value_states) if use_cache else None if self.training: warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") + # IPEX-LLM OPT: sdp attn_weights = None if not self.training and not hidden_states.requires_grad and \ use_flash_attention(query_states, key_states, attention_mask): @@ -211,207 +169,56 @@ def baichuan_attention_forward_13b( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if use_quantize_kv_cache(self.W_pack, hidden_states): - forward_function = baichuan_attention_forward_13b_quantized - else: - forward_function = baichuan_attention_forward_13b_origin - return forward_function( - self=self, - hidden_states=hidden_states, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - -def baichuan_attention_forward_13b_quantized( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() device = hidden_states.device - proj = self.W_pack(hidden_states) - proj = ( - proj.unflatten(-1, (3, self.hidden_size)) - .unsqueeze(0) - .transpose(0, -2) - .squeeze(-2) - ) - query_states = ( - proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - ) - key_states = ( - proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - ) - value_states = ( - proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - ) + qkv = self.W_pack(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_heads, + self.num_heads], dim=1) - if past_key_value is None: - kv_seq_len = key_states.shape[-2] - k_cache, v_cache = init_fp8_kv_cache( - bsz, self.num_heads, kv_seq_len, self.head_dim, - device=device - ) - else: - k_cache, v_cache = past_key_value - key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, - key_states, value_states) - past_key_value = (key_states, value_states) - - if query_states.size(2) != 1 or device.type != 'xpu': - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - else: - import linear_q4_0 - attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states) - - attn_weights = attn_weights / math.sqrt(self.head_dim) - - if attention_mask is not None: - if q_len == 1: # inference with cache - if len(attention_mask.size()) == 4: - attention_mask = attention_mask[:, :, -1:, :] - else: - attention_mask = attention_mask[:, -1:, :] - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, - torch.tensor(torch.finfo(attn_weights.dtype).min)) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_weights = attn_weights.to(hidden_states.dtype) - - if query_states.size(2) != 1 or device.type != 'xpu': - attn_output = torch.matmul(attn_weights, value_states) - else: - import linear_q4_0 - attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def baichuan_attention_forward_13b_origin( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - device = hidden_states.device - - proj = self.W_pack(hidden_states) - proj = ( - proj.unflatten(-1, (3, self.hidden_size)) - .unsqueeze(0) - .transpose(0, -2) - .squeeze(-2) - ) - query_states = ( - proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - ) - key_states = ( - proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - ) - value_states = ( - proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - ) - - kv_seq_len = key_states.shape[-2] - enough_kv_room = True + kv_seq_len = key_states.shape[2] if past_key_value is not None: - enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len) - kv_seq_len += past_key_value[0].shape[-2] - - # if past_key_value is not None: - # # reuse k, v, self_attention - # key_states = torch.cat([past_key_value[0], key_states], dim=2) - # value_states = torch.cat([past_key_value[1], value_states], dim=2) - if past_key_value is not None: - # reuse k, v, self_attention - cache_k = past_key_value[0] - cache_v = past_key_value[1] - if not enough_kv_room: - if device.type == 'xpu': - torch.xpu.empty_cache() - # allocate new - new_cache_k, new_cache_v = extend_kv_cache(bsz, - self.num_heads, - self.head_dim, - cache_k.size(2), - kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, - dtype=cache_k.dtype, - device=device) - new_cache_k[:] = cache_k - new_cache_v[:] = cache_v - cache_k = new_cache_k - cache_v = new_cache_v - - key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states) - - elif use_cache: - max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH - new_key_states, new_value_states = init_kv_cache(bsz, - self.num_heads, - self.head_dim, - kv_seq_len, - max_cache_length, - dtype=key_states.dtype, - device=device) - new_key_states[:] = key_states - new_value_states[:] = value_states - key_states = new_key_states - value_states = new_value_states + kv_seq_len += past_key_value[0].shape[2] + # IPEX-LLM OPT: kv cache and quantize kv + use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) + key_states, value_states = update_past_key_value( + past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device + ) past_key_value = (key_states, value_states) if use_cache else None if self.training: warnings.warn("xops is not supported on Intel GPU, so just use normal implementation") - attn_weights = torch.matmul( - query_states.to(dtype=key_states.dtype), key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - if attention_mask is not None: - if q_len == 1: # inference with cache - if len(attention_mask.size()) == 4: - attention_mask = attention_mask[:, :, -1:, :] - else: - attention_mask = attention_mask[:, -1:, :] - if attention_mask.shape[-2] == attn_weights.shape[-2]: - attn_weights = attn_weights + attention_mask + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -q_len:, :] else: - # support for Baichuan/Baichuan2 13B Chat running speculative decoding - # split attention mask on dim -2 - split_sizes = [attention_mask.shape[-2] - attn_weights.shape[-2], - attn_weights.shape[-2]] - # the last chunk of splited is the new attention mask - attention_mask = attention_mask.split(split_sizes, dim=-2)[-1] - attn_weights = attn_weights + attention_mask - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) + attention_mask = attention_mask[:, None, -q_len:, :] + if use_quantize_kv and q_len == 1: + import linear_q4_0 + attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states) + else: + if use_quantize_kv: + key_states, value_states = restore_fp8_kv_cache(key_states, value_states, + query_states.dtype) + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) + attn_weights = attn_weights / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states) - + if use_quantize_kv and q_len == 1: + import linear_q4_0 + attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, value_states) + else: + attn_output = torch.matmul(attn_weights.to(dtype=value_states.dtype), value_states) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 58af7a9c..dcb6de3f 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -24,6 +24,7 @@ from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_ from ipex_llm.transformers.convert import is_deepspeed_available FP8_KV_ALLOC_LENGTH = 512 +KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) # used in fused mlp forward SILU = 0 @@ -426,3 +427,49 @@ def fp16_fusion_check(proj, x, training): if device_type != "pvc": return False return True + + +def update_past_key_value(past_key_value, key_states, value_states, + kv_seq_len, use_quantize_kv, device): + bsz, num_heads, _, head_dim = key_states.shape + if use_quantize_kv: + if past_key_value is None: + k_cache, v_cache = init_fp8_kv_cache( + bsz, num_heads, kv_seq_len, head_dim, + device=device + ) + else: + k_cache, v_cache = past_key_value + key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, + key_states, value_states) + else: + if past_key_value is None: + max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + k_cache, v_cache = init_kv_cache(bsz, + num_heads, + head_dim, + kv_seq_len, + max_cache_length, + dtype=key_states.dtype, + device=device) + k_cache[...] = key_states + v_cache[...] = value_states + key_states = k_cache + value_states = v_cache + else: + k_cache, v_cache = past_key_value + if k_cache.stride(1) < kv_seq_len * k_cache.size(3): + max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH + new_k_cache, new_v_cache = extend_kv_cache(bsz, + num_heads, + head_dim, + k_cache.size(2), + max_cache_length, + dtype=k_cache.dtype, + device=device) + new_k_cache[...] = k_cache + new_v_cache[...] = v_cache + k_cache = new_k_cache + v_cache = new_v_cache + key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states) + return key_states, value_states