parent
							
								
									3ee6dec0f8
								
							
						
					
					
						commit
						6693e8ab04
					
				
					 3 changed files with 97 additions and 4 deletions
				
			
		| 
						 | 
					@ -22,7 +22,8 @@ import math
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .models.utils import (
 | 
					from .models.utils import (
 | 
				
			||||||
    init_fp8_kv_cache, append_fp8_kv_cache,
 | 
					    init_fp8_kv_cache, append_fp8_kv_cache,
 | 
				
			||||||
    init_kv_cache, append_kv_cache, extend_kv_cache
 | 
					    init_kv_cache, append_kv_cache, extend_kv_cache,
 | 
				
			||||||
 | 
					    init_unbalanced_fp8_kv_cache, append_unbalanced_fp8_kv_cache,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from typing import Optional, Dict, Tuple, Any, List
 | 
					from typing import Optional, Dict, Tuple, Any, List
 | 
				
			||||||
from transformers.cache_utils import DynamicCache
 | 
					from transformers.cache_utils import DynamicCache
 | 
				
			||||||
| 
						 | 
					@ -151,6 +152,55 @@ class DynamicNormalCache(DynamicCache):
 | 
				
			||||||
        return past_key_values
 | 
					        return past_key_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DynamicUnbalancedFp8Cache(DynamicCache):
 | 
				
			||||||
 | 
					    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
 | 
				
			||||||
 | 
					        # ignore num_hidden_layers to fix transformers >= 4.45
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        key_states: torch.Tensor,
 | 
				
			||||||
 | 
					        value_states: torch.Tensor,
 | 
				
			||||||
 | 
					        layer_idx: int,
 | 
				
			||||||
 | 
					        cache_kwargs: Optional[Dict[str, Any]]=None,
 | 
				
			||||||
 | 
					    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
				
			||||||
 | 
					        # fix converting empty DynamicCache in transformers >= 4.45
 | 
				
			||||||
 | 
					        if key_states == []:
 | 
				
			||||||
 | 
					            return key_states, value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_size, num_heads, seq_len, k_head_dim = key_states.shape
 | 
				
			||||||
 | 
					        _, _, _, v_head_dim = value_states.shape
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if layer_idx == 0:
 | 
				
			||||||
 | 
					            if hasattr(self, "_seen_tokens"):
 | 
				
			||||||
 | 
					                # 4.39 uses `_seen_tokens`
 | 
				
			||||||
 | 
					                self._seen_tokens += seq_len
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # 4.37 uses `seen_tokens`
 | 
				
			||||||
 | 
					                self.seen_tokens += seq_len
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Update the cache
 | 
				
			||||||
 | 
					        if len(self.key_cache) <= layer_idx:
 | 
				
			||||||
 | 
					            k_cache, v_cache = init_unbalanced_fp8_kv_cache(
 | 
				
			||||||
 | 
					                batch_size, num_heads, seq_len, k_head_dim, v_head_dim,
 | 
				
			||||||
 | 
					                device=key_states.device,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
 | 
				
			||||||
 | 
					                                                              key_states, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.key_cache.append(k_cache)
 | 
				
			||||||
 | 
					            self.value_cache.append(v_cache)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            k_cache = self.key_cache[layer_idx]
 | 
				
			||||||
 | 
					            v_cache = self.value_cache[layer_idx]
 | 
				
			||||||
 | 
					            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
 | 
				
			||||||
 | 
					                                                              key_states, value_states)
 | 
				
			||||||
 | 
					            self.key_cache[layer_idx] = k_cache
 | 
				
			||||||
 | 
					            self.value_cache[layer_idx] = v_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
 | 
					# Copied from transformers.models.llama.modeling_llama.repeat_kv
 | 
				
			||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
					def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -273,11 +273,11 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
 | 
					                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
 | 
				
			||||||
        elif seq_length != kv_length and seq_length <= 32:
 | 
					        elif seq_length != kv_length and seq_length <= 32:
 | 
				
			||||||
            # todo: add scale support
 | 
					            # todo: add further scale support
 | 
				
			||||||
            if key.dtype == torch.uint8:
 | 
					            if key.dtype == torch.uint8:
 | 
				
			||||||
                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
 | 
					                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                attn_output = xe_addons.sdp(query, key, value, mask)
 | 
					                attn_output = xe_addons.sdp(query, key, value, mask, scale)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if key.dtype == torch.uint8:
 | 
					            if key.dtype == torch.uint8:
 | 
				
			||||||
                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
 | 
					                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -138,6 +138,49 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value):
 | 
				
			||||||
    return new_k_cache, new_v_cache
 | 
					    return new_k_cache, new_v_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init_unbalanced_fp8_kv_cache(batch_size, num_heads, current_length,
 | 
				
			||||||
 | 
					                                 k_head_dim, v_head_dim, device):
 | 
				
			||||||
 | 
					    # for case which k head dim is different from v head dim
 | 
				
			||||||
 | 
					    max_length = current_length + FP8_KV_ALLOC_LENGTH
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    k_cache_storage = torch.empty(batch_size, num_heads, max_length, k_head_dim,
 | 
				
			||||||
 | 
					                                  dtype=torch.uint8, device=device)
 | 
				
			||||||
 | 
					    k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, k_head_dim),
 | 
				
			||||||
 | 
					                                         k_cache_storage.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    v_cache_storage = torch.empty(batch_size, num_heads, max_length, v_head_dim,
 | 
				
			||||||
 | 
					                                  dtype=torch.uint8, device=device)
 | 
				
			||||||
 | 
					    v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, v_head_dim),
 | 
				
			||||||
 | 
					                                         v_cache_storage.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					    return k_cache, v_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def append_unbalanced_fp8_kv_cache(k_cache, v_cache, key, value):
 | 
				
			||||||
 | 
					    batch_size, num_heads, cur_length, k_head_dim = k_cache.shape
 | 
				
			||||||
 | 
					    _, _, _, v_head_dim = v_cache.shape
 | 
				
			||||||
 | 
					    new_length = cur_length + key.size(2)
 | 
				
			||||||
 | 
					    new_k_size = (batch_size, num_heads, new_length, k_head_dim)
 | 
				
			||||||
 | 
					    new_v_size = (batch_size, num_heads, new_length, v_head_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if k_cache.stride(1) < new_length * k_cache.size(3):
 | 
				
			||||||
 | 
					        new_k_cache, new_v_cache = init_unbalanced_fp8_kv_cache(batch_size, num_heads, new_length,
 | 
				
			||||||
 | 
					                                                                k_head_dim, v_head_dim, key.device)
 | 
				
			||||||
 | 
					        new_k_cache = new_k_cache.as_strided(new_k_size, new_k_cache.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					        new_v_cache = new_v_cache.as_strided(new_v_size, new_v_cache.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					        new_k_cache[:, :, :cur_length, :] = k_cache
 | 
				
			||||||
 | 
					        new_v_cache[:, :, :cur_length, :] = v_cache
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        new_k_cache = k_cache.as_strided(new_k_size, k_cache.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					        new_v_cache = v_cache.as_strided(new_v_size, v_cache.stride(), storage_offset=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    import xe_addons
 | 
				
			||||||
 | 
					    xe_addons.quantize_key_value(key, value,
 | 
				
			||||||
 | 
					                                 new_k_cache[:, :, cur_length:new_length, :],
 | 
				
			||||||
 | 
					                                 new_v_cache[:, :, cur_length:new_length, :])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_k_cache, new_v_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def restore_fp8_kv_cache(k_cache, v_cache, dtype):
 | 
					def restore_fp8_kv_cache(k_cache, v_cache, dtype):
 | 
				
			||||||
    key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
 | 
					    key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
 | 
				
			||||||
    value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
 | 
					    value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue