diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index 18acf17b..d4e9d9d4 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -32,7 +32,6 @@ class DynamicFp8Cache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         batch_size, num_heads, seq_len, head_dim = key_states.shape
@@ -50,18 +49,15 @@ class DynamicFp8Cache(DynamicCache):
             k_cache, v_cache = init_fp8_kv_cache(
                 batch_size, num_heads, seq_len, head_dim,
                 device=key_states.device,
-                new_layout=new_layout,
             )
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
             self.key_cache.append(k_cache)
             self.value_cache.append(v_cache)
         else:
             k_cache = self.key_cache[layer_idx]
             v_cache = self.value_cache[layer_idx]
 
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
             self.key_cache[layer_idx] = k_cache
             self.value_cache[layer_idx] = v_cache
 
@@ -77,7 +73,6 @@ class DynamicNormalCache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,  # useless, just keep same as DynamicFp8Cache
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         batch_size, num_heads, seq_len, head_dim = key_states.shape
diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index b25964fe..8e54ee55 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -128,15 +128,15 @@ def baichuan_attention_forward_7b_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device, new_layout=True
+                device=device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
-                                                           value_states, new_layout=True)
+                                                           value_states)
             past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
     if query_states.size(2) != 1 or query_states.device.type != 'xpu':
diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index a5848e6e..550035c3 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -142,12 +142,12 @@ def baichuan_attention_forward_7b_quantized(
             kv_seq_len = key_states.shape[-2]
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device, new_layout=True
+                device=device
             )
         else:
             k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
 
         past_key_value = (key_states, value_states) if use_cache else None
 
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index ed39c00e..9032bbf5 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -280,8 +280,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
                                              n_kv_head,
                                              seq_len,
                                              head_dim,
-                                             query_layer.device,
-                                             new_layout=True)
+                                             query_layer.device)
         k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
     else:
         k_cache, v_cache = kv_cache
@@ -289,8 +288,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
         v_cache = v_cache.permute(1, 2, 0, 3)
         # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
 
-        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
-                                               new_layout=True)
+        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
 
     if attention_mask is not None:
         attention_mask = ~attention_mask
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 45775eac..703d3163 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -438,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)
@@ -446,7 +446,7 @@ def llama_attention_forward_4_31_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
@@ -1067,13 +1067,11 @@ def llama_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                              self.layer_idx, cache_kwargs,
-                                                              new_layout=True)
+                                                              self.layer_idx, cache_kwargs)
         else:
             cache_kwargs = None  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                              self.layer_idx, cache_kwargs,
-                                                              new_layout=True)
+                                                              self.layer_idx, cache_kwargs)
     kv_seq_len = key_states.shape[-2]
     if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index fd7509db..10791a06 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -299,7 +299,7 @@ def mistral_attention_forward_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)
@@ -307,7 +307,7 @@ def mistral_attention_forward_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
@@ -680,13 +680,11 @@ def mistral_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                              self.layer_idx, cache_kwargs,
-                                                              new_layout=True)
+                                                              self.layer_idx, cache_kwargs)
         else:
             cache_kwargs = None  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                              self.layer_idx, cache_kwargs,
-                                                              new_layout=True)
+                                                              self.layer_idx, cache_kwargs)
     kv_seq_len = key_states.shape[-2]
     if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
diff --git a/python/llm/src/ipex_llm/transformers/models/phi.py b/python/llm/src/ipex_llm/transformers/models/phi.py
index 43a63827..9a5a01a5 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi.py
@@ -129,7 +129,7 @@ def attention_forward(
     invalidInputError(past_key_value is not None,
                       "`past_key_value` cannot be None")
     key_states, value_states = past_key_value.update(key_states, value_states,
-                                                      self.layer_idx, None, new_layout=True)
+                                                      self.layer_idx, None)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py
index ef3da70d..7405d552 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -446,7 +446,7 @@ def qwen_attention_forward_quantized(
             max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
                 query.size(0), self.num_heads, kv_seq_len, self.head_dim,
-                device=query.device, new_layout=True
+                device=query.device
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
         else:
@@ -461,7 +461,7 @@ def qwen_attention_forward_quantized(
             v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
 
-            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)
+            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
 
         attn_output, attn_weight = core_attn(
             self, query, key, value, causal_mask, attention_mask, head_mask
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 9cf7a640..7494e617 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -358,8 +358,7 @@ def qwen2_attention_forward_quantized(
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, cache_kwargs,
-                                                          new_layout=True)
+                                                          self.layer_idx, cache_kwargs)
 
     if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
             and not hidden_states.requires_grad:
diff --git a/python/llm/src/ipex_llm/transformers/models/starcoder2.py b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
index c8e23eac..014bd51d 100644
--- a/python/llm/src/ipex_llm/transformers/models/starcoder2.py
+++ b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -132,7 +132,7 @@ def attention_forward(
 
     use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
     key_states, value_states = past_key_value.update(key_states, value_states,
-                                                      self.layer_idx, None, new_layout=True)
+                                                      self.layer_idx, None)
 
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 3691886b..e2c84488 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -96,7 +96,7 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
                            1 < x.size(0) and x.size(0) <= 8)
 
 
-def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
     max_length = current_length + FP8_KV_ALLOC_LENGTH
 
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
@@ -104,7 +104,6 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
 
-    # ignore `new_layout`, will remove it in next PR
     v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
     v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
@@ -112,14 +111,14 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     return k_cache, v_cache
 
 
-def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
+def append_fp8_kv_cache(k_cache, v_cache, key, value):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)
 
     if k_cache.stride(1) < new_length * k_cache.size(3):
         new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
-                                                     head_dim, key.device, new_layout)
+                                                     head_dim, key.device)
         new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
         new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
         new_k_cache[:, :, :cur_length, :] = k_cache
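
Note: the sketch below is illustrative only and is not part of the patch. It re-creates, in plain PyTorch, the pre-allocate-and-grow pattern that init_fp8_kv_cache/append_fp8_kv_cache in utils.py rely on: allocate current_length + FP8_KV_ALLOC_LENGTH slots up front, expose a zero-length strided view of that storage, widen the view on each append, and reallocate only when the headroom runs out. The fp8 quantization itself is omitted, and ALLOC_LENGTH plus the helper names are stand-ins, not the library API.

import torch

ALLOC_LENGTH = 256  # assumed headroom, playing the role of FP8_KV_ALLOC_LENGTH


def init_kv_cache(batch_size, num_heads, current_length, head_dim, dtype, device):
    # Allocate more rows than currently needed, then expose a zero-length view
    # of the same storage; later appends widen the view without reallocating.
    max_length = current_length + ALLOC_LENGTH
    storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                          dtype=dtype, device=device)
    return storage.as_strided((batch_size, num_heads, 0, head_dim),
                              storage.stride(), storage_offset=0)


def append_kv_cache(cache, new_states):
    batch_size, num_heads, cur_length, head_dim = cache.shape
    new_length = cur_length + new_states.size(2)
    new_size = (batch_size, num_heads, new_length, head_dim)
    if cache.stride(1) < new_length * cache.size(3):
        # Headroom exhausted: allocate a bigger buffer and copy the old cache over.
        bigger = init_kv_cache(batch_size, num_heads, new_length, head_dim,
                               new_states.dtype, new_states.device)
        bigger = bigger.as_strided(new_size, bigger.stride(), storage_offset=0)
        bigger[:, :, :cur_length, :] = cache
        bigger[:, :, cur_length:, :] = new_states
        return bigger
    # Enough headroom: widen the strided view in place and write the new tokens.
    cache = cache.as_strided(new_size, cache.stride(), storage_offset=0)
    cache[:, :, cur_length:, :] = new_states
    return cache


if __name__ == "__main__":
    # Prefill with 8 tokens, then append one decode step.
    k = torch.randn(1, 4, 8, 64)
    cache = init_kv_cache(1, 4, 8, 64, k.dtype, k.device)
    cache = append_kv_cache(cache, k)
    cache = append_kv_cache(cache, torch.randn(1, 4, 1, 64))
    print(cache.shape)  # torch.Size([1, 4, 9, 64])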