remove new_layout parameter (#10906)

commit d884c62dc4 (parent fbcd7bc737)
Author: Yishuo Wang
Date:   2024-04-29 10:31:50 +08:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)

11 changed files with 25 additions and 38 deletions
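In short, this commit drops the transitional new_layout flag from the FP8 KV-cache helpers and from the cache classes' update() method, and updates every model-specific attention forward to call them without it. A rough before/after sketch of the call pattern, written as Python comments (the argument names are taken from the hunks below; the surrounding module code is not part of this commit):

# Before this commit (call sites removed below):
#     k_cache, v_cache = init_fp8_kv_cache(bsz, self.num_heads, kv_seq_len, self.head_dim,
#                                          device=device, new_layout=True)
#     key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
#                                                    value_states, new_layout=True)
#     key_states, value_states = past_key_value.update(key_states, value_states,
#                                                      self.layer_idx, cache_kwargs,
#                                                      new_layout=True)
#
# After this commit (call sites added below):
#     k_cache, v_cache = init_fp8_kv_cache(bsz, self.num_heads, kv_seq_len, self.head_dim,
#                                          device=device)
#     key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
#     key_states, value_states = past_key_value.update(key_states, value_states,
#                                                      self.layer_idx, cache_kwargs)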


@@ -32,7 +32,6 @@ class DynamicFp8Cache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, num_heads, seq_len, head_dim = key_states.shape
@@ -50,18 +49,15 @@ class DynamicFp8Cache(DynamicCache):
             k_cache, v_cache = init_fp8_kv_cache(
                 batch_size, num_heads, seq_len, head_dim,
                 device=key_states.device,
-                new_layout=new_layout,
             )
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
             self.key_cache.append(k_cache)
             self.value_cache.append(v_cache)
         else:
             k_cache = self.key_cache[layer_idx]
             v_cache = self.value_cache[layer_idx]
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
             self.key_cache[layer_idx] = k_cache
             self.value_cache[layer_idx] = v_cache
@@ -77,7 +73,6 @@ class DynamicNormalCache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,  # useless, just keep same as DynamicFp8Cache
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, num_heads, seq_len, head_dim = key_states.shape


@@ -128,15 +128,15 @@ def baichuan_attention_forward_7b_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device, new_layout=True
+                device=device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
-                                                           value_states, new_layout=True)
+                                                           value_states)
             past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
     if query_states.size(2) != 1 or query_states.device.type != 'xpu':


@@ -142,12 +142,12 @@ def baichuan_attention_forward_7b_quantized(
             kv_seq_len = key_states.shape[-2]
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device, new_layout=True
+                device=device
             )
         else:
             k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         past_key_value = (key_states, value_states) if use_cache else None


@@ -280,8 +280,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
                                                n_kv_head,
                                                seq_len,
                                                head_dim,
-                                               query_layer.device,
-                                               new_layout=True)
+                                               query_layer.device)
            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
        else:
            k_cache, v_cache = kv_cache
@@ -289,8 +288,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
            v_cache = v_cache.permute(1, 2, 0, 3)
            # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
-           k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
-                                                  new_layout=True)
+           k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)

        if attention_mask is not None:
            attention_mask = ~attention_mask


@@ -438,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)
@@ -446,7 +446,7 @@ def llama_attention_forward_4_31_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
@@ -1067,13 +1067,11 @@ def llama_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs,
-                                                             new_layout=True)
+                                                             self.layer_idx, cache_kwargs)
         else:
             cache_kwargs = None  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs,
-                                                             new_layout=True)
+                                                             self.layer_idx, cache_kwargs)
     kv_seq_len = key_states.shape[-2]
     if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,


@@ -299,7 +299,7 @@ def mistral_attention_forward_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)
@@ -307,7 +307,7 @@ def mistral_attention_forward_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
@@ -680,13 +680,11 @@ def mistral_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs,
-                                                             new_layout=True)
+                                                             self.layer_idx, cache_kwargs)
         else:
             cache_kwargs = None  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs,
-                                                             new_layout=True)
+                                                             self.layer_idx, cache_kwargs)
     kv_seq_len = key_states.shape[-2]
     if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,


@@ -129,7 +129,7 @@ def attention_forward(
     invalidInputError(past_key_value is not None,
                       "`past_key_value` cannot be None")
     key_states, value_states = past_key_value.update(key_states, value_states,
-                                                     self.layer_idx, None, new_layout=True)
+                                                     self.layer_idx, None)
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)


@@ -446,7 +446,7 @@ def qwen_attention_forward_quantized(
             max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
                 query.size(0), self.num_heads, kv_seq_len, self.head_dim,
-                device=query.device, new_layout=True
+                device=query.device
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
         else:
@@ -461,7 +461,7 @@ def qwen_attention_forward_quantized(
             v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
-            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)
+            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)

         attn_output, attn_weight = core_attn(
             self, query, key, value, causal_mask, attention_mask, head_mask


@@ -358,8 +358,7 @@ def qwen2_attention_forward_quantized(
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs,
-                                                         new_layout=True)
+                                                         self.layer_idx, cache_kwargs)

     if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
             and not hidden_states.requires_grad:


@@ -132,7 +132,7 @@ def attention_forward(
     use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
     key_states, value_states = past_key_value.update(key_states, value_states,
-                                                     self.layer_idx, None, new_layout=True)
+                                                     self.layer_idx, None)
     if use_quantize_kv and q_len == 1:
         import linear_q4_0


@@ -96,7 +96,7 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
            1 < x.size(0) and x.size(0) <= 8)


-def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
     max_length = current_length + FP8_KV_ALLOC_LENGTH

     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
@@ -104,7 +104,6 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)

-    # ignore `new_layout`, will remove it in next PR
     v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
     v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
@@ -112,14 +111,14 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     return k_cache, v_cache


-def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
+def append_fp8_kv_cache(k_cache, v_cache, key, value):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)

     if k_cache.stride(1) < new_length * k_cache.size(3):
         new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
-                                                     head_dim, key.device, new_layout)
+                                                     head_dim, key.device)
         new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
         new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
         new_k_cache[:, :, :cur_length, :] = k_cache
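
For readers unfamiliar with the helpers touched above, here is a minimal, self-contained sketch of the pre-allocate/as_strided growth pattern they use. It is an illustration only: the repository's init_fp8_kv_cache/append_fp8_kv_cache allocate uint8 storage and quantize keys and values to FP8, which this sketch omits, and ALLOC_LENGTH, init_kv_cache, and append_kv_cache are illustrative names (the real constant is FP8_KV_ALLOC_LENGTH, whose value is not shown in this commit).

import torch

# Stand-in for the repository's FP8_KV_ALLOC_LENGTH constant (assumed value).
ALLOC_LENGTH = 256


def init_kv_cache(batch_size, num_heads, current_length, head_dim, device):
    # Pre-allocate extra room, then expose a zero-length view over the same storage;
    # append_kv_cache() can grow the view in place until the storage runs out.
    # (The real init_fp8_kv_cache allocates uint8 storage for FP8-quantized values.)
    max_length = current_length + ALLOC_LENGTH
    storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                          dtype=torch.float32, device=device)
    return storage.as_strided((batch_size, num_heads, 0, head_dim),
                              storage.stride(), storage_offset=0)


def append_kv_cache(cache, new_states):
    # Append new_states along the sequence dimension, reallocating only when the
    # pre-allocated storage is too small (mirrors the stride(1) check in the diff).
    batch_size, num_heads, cur_length, head_dim = cache.shape
    new_length = cur_length + new_states.size(2)
    new_size = (batch_size, num_heads, new_length, head_dim)
    if cache.stride(1) < new_length * head_dim:
        bigger = init_kv_cache(batch_size, num_heads, new_length, head_dim, new_states.device)
        bigger = bigger.as_strided(new_size, bigger.stride(), storage_offset=0)
        bigger[:, :, :cur_length, :] = cache
        cache = bigger
    else:
        cache = cache.as_strided(new_size, cache.stride(), storage_offset=0)
    cache[:, :, cur_length:new_length, :] = new_states
    return cache


if __name__ == "__main__":
    k_cache = init_kv_cache(1, 8, 0, 64, device="cpu")
    step = torch.randn(1, 8, 16, 64)
    k_cache = append_kv_cache(k_cache, step)
    print(k_cache.shape)  # torch.Size([1, 8, 16, 64])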