qwen2 fp8 cache (#10446)

* qwen2 fp8 cache

* fix style check
Xin Qiu 2024-03-19 08:32:39 +08:00 committed by GitHub
parent 9e763b049c
commit bbd749dceb
2 changed files with 30 additions and 28 deletions


@@ -29,6 +29,7 @@ class DynamicFp8Cache(DynamicCache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
+        new_layout=False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, num_heads, seq_len, head_dim = key_states.shape
@@ -41,15 +42,18 @@ class DynamicFp8Cache(DynamicCache):
             k_cache, v_cache = init_fp8_kv_cache(
                 batch_size, num_heads, seq_len, head_dim,
                 device=key_states.device,
+                new_layout=new_layout,
             )
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
+                                                   new_layout=new_layout)
             self.key_cache.append(k_cache)
             self.value_cache.append(v_cache)
         else:
             k_cache = self.key_cache[layer_idx]
             v_cache = self.value_cache[layer_idx]
-            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
+            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
+                                                   new_layout=new_layout)
             self.key_cache[layer_idx] = k_cache
             self.value_cache[layer_idx] = v_cache
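For readers following the change outside the diff, here is a minimal, runnable sketch of the update() path edited above. It is an illustration only: init_kv_cache_stub and append_kv_cache_stub are hypothetical stand-ins for the library's init_fp8_kv_cache and append_fp8_kv_cache (they allocate and concatenate plain fp16 buffers instead of FP8 storage), and the layer_idx == len(self.key_cache) test is the usual first-fill check assumed from the DynamicCache pattern. The point is just to show how the new new_layout flag is threaded through both the first-fill and the append branches.

from typing import Any, Dict, List, Optional, Tuple

import torch


def init_kv_cache_stub(batch, heads, seq, dim, device, new_layout=False):
    # Hypothetical stand-in for init_fp8_kv_cache: allocate empty buffers
    # (the real helper would pick an FP8 layout based on new_layout).
    shape = (batch, heads, 0, dim)
    return (torch.empty(shape, device=device, dtype=torch.float16),
            torch.empty(shape, device=device, dtype=torch.float16))


def append_kv_cache_stub(k_cache, v_cache, k, v, new_layout=False):
    # Hypothetical stand-in for append_fp8_kv_cache: concatenate along seq dim.
    return torch.cat([k_cache, k], dim=2), torch.cat([v_cache, v], dim=2)


class CacheSketch:
    def __init__(self):
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(self, key_states, value_states, layer_idx,
               cache_kwargs: Optional[Dict[str, Any]] = None,
               new_layout=False) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_heads, seq_len, head_dim = key_states.shape
        if layer_idx == len(self.key_cache):
            # First tokens for this layer: create the cache, then append.
            k_cache, v_cache = init_kv_cache_stub(
                batch_size, num_heads, seq_len, head_dim,
                device=key_states.device, new_layout=new_layout)
            k_cache, v_cache = append_kv_cache_stub(k_cache, v_cache,
                                                    key_states, value_states,
                                                    new_layout=new_layout)
            self.key_cache.append(k_cache)
            self.value_cache.append(v_cache)
        else:
            # Later tokens: append to the existing per-layer cache.
            k_cache, v_cache = append_kv_cache_stub(self.key_cache[layer_idx],
                                                    self.value_cache[layer_idx],
                                                    key_states, value_states,
                                                    new_layout=new_layout)
            self.key_cache[layer_idx] = k_cache
            self.value_cache[layer_idx] = v_cache
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


if __name__ == "__main__":
    cache = CacheSketch()
    k = torch.randn(1, 32, 8, 128, dtype=torch.float16)
    v = torch.randn(1, 32, 8, 128, dtype=torch.float16)
    k_all, v_all = cache.update(k, v, layer_idx=0, new_layout=True)
    print(k_all.shape, v_all.shape)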


@@ -354,43 +354,41 @@ def qwen2_attention_forward_quantized(
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, cache_kwargs)
+                                                          self.layer_idx, cache_kwargs,
+                                                          new_layout=True)
 
-    if q_len != 1:
+    if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
+            and not hidden_states.requires_grad:
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                          attention_mask)
+        attn_weights = None
+    else:
         key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
         key = repeat_kv(key, self.num_key_value_groups)
         value = repeat_kv(value, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states, key.transpose(2, 3))
-    else:
-        import linear_q4_0
-        attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
-    attn_weights = attn_weights / math.sqrt(self.head_dim)
+        attn_weights = attn_weights / math.sqrt(self.head_dim)
 
-    invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
-                      ("Attention weights should be of size "
-                       f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
-                       "but is {attn_weights.size()}"))
+        invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
+                          ("Attention weights should be of size "
+                           f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
+                           "but is {attn_weights.size()}"))
 
-    if attention_mask is not None:
-        invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
-                          (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
-                           f" but is {attention_mask.size()}"))
-        attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:
+            invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
+                              (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
+                               f" but is {attention_mask.size()}"))
+            attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                         training=self.training)
-    if q_len != 1:
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
+                                             training=self.training)
         attn_output = torch.matmul(attn_weights, value)
-    else:
-        import linear_q4_0
-        attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
-                                                        value_states.transpose(-1, -2))
 
     invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
                       "`attn_output` should be of size "