remove new_layout parameter (#10906)
parent fbcd7bc737
commit d884c62dc4
11 changed files with 25 additions and 38 deletions
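In short: the already-ignored `new_layout` flag (the old code itself notes it is unused) is deleted end to end, from the two FP8 KV-cache helpers, from the cache classes' `update` methods, and from every call site. The helper signatures before and after, taken from the final hunks of this diff (bodies elided here):

# before
# def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False): ...
# def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False): ...

# after
def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
    ...

def append_fp8_kv_cache(k_cache, v_cache, key, value):
    ...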
@@ -32,7 +32,6 @@ class DynamicFp8Cache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         batch_size, num_heads, seq_len, head_dim = key_states.shape

@@ -50,18 +49,15 @@ class DynamicFp8Cache(DynamicCache):
             k_cache, v_cache = init_fp8_kv_cache(
                 batch_size, num_heads, seq_len, head_dim,
                 device=key_states.device,
-                new_layout=new_layout,
             )
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
 
             self.key_cache.append(k_cache)
             self.value_cache.append(v_cache)
         else:
             k_cache = self.key_cache[layer_idx]
             v_cache = self.value_cache[layer_idx]
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
-                                                   new_layout=new_layout)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
             self.key_cache[layer_idx] = k_cache
             self.value_cache[layer_idx] = v_cache

@@ -77,7 +73,6 @@ class DynamicNormalCache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
-        new_layout=False,  # useless, just keep same as DynamicFp8Cache
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         batch_size, num_heads, seq_len, head_dim = key_states.shape
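With both `update` signatures trimmed, the cache classes are called exactly as the model-file hunks below show. A hedged usage sketch (the import path and the bare construction are assumptions, not part of this diff, and it presumes an environment where the FP8 cache kernels are available):

import torch
# Assumed import path, for illustration only; the signature change is what this diff shows.
from ipex_llm.transformers.kv import DynamicFp8Cache

past_key_value = DynamicFp8Cache()
# [batch_size, num_heads, seq_len, head_dim], the shape unpacked inside update()
key_states = torch.randn(1, 8, 16, 64)
value_states = torch.randn(1, 8, 16, 64)

# new_layout is gone; layer_idx and the optional cache_kwargs remain.
key_states, value_states = past_key_value.update(key_states, value_states,
                                                 layer_idx=0, cache_kwargs=None)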
@@ -128,15 +128,15 @@ def baichuan_attention_forward_7b_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device, new_layout=True
+                device=device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
-                                                           value_states, new_layout=True)
+                                                           value_states)
             past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
     if query_states.size(2) != 1 or query_states.device.type != 'xpu':

@@ -142,12 +142,12 @@ def baichuan_attention_forward_7b_quantized(
             kv_seq_len = key_states.shape[-2]
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=device, new_layout=True
+                device=device
             )
         else:
             k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
 
         past_key_value = (key_states, value_states) if use_cache else None
 
@@ -280,8 +280,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
                                                n_kv_head,
                                                seq_len,
                                                head_dim,
-                                               query_layer.device,
-                                               new_layout=True)
+                                               query_layer.device)
             k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
         else:
             k_cache, v_cache = kv_cache

@@ -289,8 +288,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
             v_cache = v_cache.permute(1, 2, 0, 3)
             # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
 
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
-                                                   new_layout=True)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
 
         if attention_mask is not None:
             attention_mask = ~attention_mask
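A note on the `permute(1, 2, 0, 3)` context above: the chatglm2 path keeps its quantized cache seq-major and flips it into the [bs, n_kv_head, seq_len, head_dim] layout named in the comment before appending. A self-contained shape check in plain torch (dimension names follow that comment; the sizes are arbitrary):

import torch

seq_len, bs, n_kv_head, head_dim = 16, 1, 2, 64

# Seq-major cache, as stored on the chatglm2 path: [seq_len, bs, n_kv_head, head_dim]
v_cache_seq_major = torch.zeros(seq_len, bs, n_kv_head, head_dim, dtype=torch.uint8)

# permute(1, 2, 0, 3) reorders it into the layout append_fp8_kv_cache works with.
v_cache = v_cache_seq_major.permute(1, 2, 0, 3)
assert v_cache.shape == (bs, n_kv_head, seq_len, head_dim)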
@@ -438,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)

@@ -446,7 +446,7 @@ def llama_attention_forward_4_31_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
@@ -1067,13 +1067,11 @@ def llama_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                              self.layer_idx, cache_kwargs,
-                                                              new_layout=True)
+                                                              self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, cache_kwargs,
-                                                          new_layout=True)
+                                                          self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
     if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -299,7 +299,7 @@ def mistral_attention_forward_quantized(
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
-                device=query_states.device, new_layout=True
+                device=query_states.device
             )
             key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
                                                            key_states, value_states)

@@ -307,7 +307,7 @@ def mistral_attention_forward_quantized(
     else:
         k_cache, v_cache = past_key_value
         key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
-                                                       key_states, value_states, new_layout=True)
+                                                       key_states, value_states)
         kv_seq_len = key_states.shape[-2]
         past_key_value = (key_states, value_states)
 
@@ -680,13 +680,11 @@ def mistral_attention_forward_4_36_quantized(
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
-                                                              self.layer_idx, cache_kwargs,
-                                                              new_layout=True)
+                                                              self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, cache_kwargs,
-                                                          new_layout=True)
+                                                          self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
     if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
         key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -129,7 +129,7 @@ def attention_forward(
         invalidInputError(past_key_value is not None,
                           "`past_key_value` cannot be None")
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, None, new_layout=True)
+                                                          self.layer_idx, None)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
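The `repeat_kv` calls that follow the cache update expand the grouped KV heads to match the number of query heads. A self-contained sketch of that expansion in plain torch (this mirrors the standard Hugging Face helper rather than quoting this repository's copy):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_kv_heads * n_rep, seq_len, head_dim]
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

key_states = torch.randn(1, 2, 16, 64)   # 2 KV heads
key_states = repeat_kv(key_states, 4)    # num_key_value_groups = 4
print(key_states.shape)                  # torch.Size([1, 8, 16, 64])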
@@ -446,7 +446,7 @@ def qwen_attention_forward_quantized(
             max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fp8_kv_cache(
                 query.size(0), self.num_heads, kv_seq_len, self.head_dim,
-                device=query.device, new_layout=True
+                device=query.device
             )
             key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
         else:

@@ -461,7 +461,7 @@ def qwen_attention_forward_quantized(
             v_cache = v_cache.transpose(1, 2)
             # k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
 
-            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)
+            key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
 
         attn_output, attn_weight = core_attn(
             self, query, key, value, causal_mask, attention_mask, head_mask
@@ -358,8 +358,7 @@ def qwen2_attention_forward_quantized(
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, cache_kwargs,
-                                                          new_layout=True)
+                                                          self.layer_idx, cache_kwargs)
 
     if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
             and not hidden_states.requires_grad:
@@ -132,7 +132,7 @@ def attention_forward(
     use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
 
     key_states, value_states = past_key_value.update(key_states, value_states,
-                                                      self.layer_idx, None, new_layout=True)
+                                                      self.layer_idx, None)
 
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
@@ -96,7 +96,7 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
             1 < x.size(0) and x.size(0) <= 8)
 
 
-def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
+def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
     max_length = current_length + FP8_KV_ALLOC_LENGTH
 
     k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,

@@ -104,7 +104,6 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
                                          k_cache_storage.stride(), storage_offset=0)
 
-    # ignore `new_layout`, will remove it in next PR
     v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                                   dtype=torch.uint8, device=device)
     v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),

@@ -112,14 +111,14 @@ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
     return k_cache, v_cache
 
 
-def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
+def append_fp8_kv_cache(k_cache, v_cache, key, value):
     batch_size, num_heads, cur_length, head_dim = k_cache.shape
     new_length = cur_length + key.size(2)
     new_size = (batch_size, num_heads, new_length, head_dim)
 
     if k_cache.stride(1) < new_length * k_cache.size(3):
         new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
-                                                     head_dim, key.device, new_layout)
+                                                     head_dim, key.device)
         new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
         new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
         new_k_cache[:, :, :cur_length, :] = k_cache
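The helpers above keep their pre-allocation scheme unchanged apart from the dropped flag: a buffer of `current_length + FP8_KV_ALLOC_LENGTH` positions is allocated once, a zero-length `as_strided` view over it is returned, and each append widens that view until the stride check forces a larger buffer. A self-contained toy version of the pattern in plain torch (uint8 stands in for the FP8 payload, `ALLOC_LENGTH` is a made-up constant, and the quantization step is left out):

import torch

ALLOC_LENGTH = 256  # stand-in for FP8_KV_ALLOC_LENGTH

def init_cache(batch_size, num_heads, current_length, head_dim, device="cpu"):
    max_length = current_length + ALLOC_LENGTH
    storage = torch.empty(batch_size, num_heads, max_length, head_dim,
                          dtype=torch.uint8, device=device)
    # Zero-length view over the big buffer; appends extend it in place.
    return storage.as_strided((batch_size, num_heads, 0, head_dim),
                              storage.stride(), storage_offset=0)

def append_cache(cache, new_tokens):
    batch_size, num_heads, cur_length, head_dim = cache.shape
    new_length = cur_length + new_tokens.size(2)
    new_size = (batch_size, num_heads, new_length, head_dim)
    if cache.stride(1) < new_length * cache.size(3):
        # Buffer exhausted: allocate a larger one and copy the existing entries over.
        bigger = init_cache(batch_size, num_heads, new_length, head_dim, new_tokens.device)
        bigger = bigger.as_strided(new_size, bigger.stride(), storage_offset=0)
        bigger[:, :, :cur_length, :] = cache
        cache = bigger
    else:
        cache = cache.as_strided(new_size, cache.stride(), storage_offset=0)
    cache[:, :, cur_length:new_length, :] = new_tokens
    return cache

k_cache = init_cache(1, 2, 4, 8)
for _ in range(2):
    k_cache = append_cache(k_cache, torch.ones(1, 2, 4, 8, dtype=torch.uint8))
print(k_cache.shape)  # torch.Size([1, 2, 8, 8])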