diff --git a/python/llm/src/bigdl/llm/transformers/kv.py b/python/llm/src/bigdl/llm/transformers/kv.py index 0d3ad897..71aa6c9f 100644 --- a/python/llm/src/bigdl/llm/transformers/kv.py +++ b/python/llm/src/bigdl/llm/transformers/kv.py @@ -29,6 +29,7 @@ class DynamicFp8Cache(DynamicCache): value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]]=None, + new_layout=False, ) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, num_heads, seq_len, head_dim = key_states.shape @@ -41,15 +42,18 @@ class DynamicFp8Cache(DynamicCache): k_cache, v_cache = init_fp8_kv_cache( batch_size, num_heads, seq_len, head_dim, device=key_states.device, + new_layout=new_layout, ) - k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states) + k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states, + new_layout=new_layout) self.key_cache.append(k_cache) self.value_cache.append(v_cache) else: k_cache = self.key_cache[layer_idx] v_cache = self.value_cache[layer_idx] - k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states) + k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states, + new_layout=new_layout) self.key_cache[layer_idx] = k_cache self.value_cache[layer_idx] = v_cache diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py index f9fbf0d6..81864db0 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py @@ -354,43 +354,41 @@ def qwen2_attention_forward_quantized( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, - self.layer_idx, cache_kwargs) + self.layer_idx, cache_kwargs, + new_layout=True) - if q_len != 1: + if q_len == 1 and query_states.device.type == 'xpu' and not self.training \ + and not hidden_states.requires_grad: + import linear_q4_0 + attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, + attention_mask) + attn_weights = None + else: key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype) key = repeat_kv(key, self.num_key_value_groups) value = repeat_kv(value, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key.transpose(2, 3)) - else: - import linear_q4_0 - attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states) + attn_weights = attn_weights / math.sqrt(self.head_dim) - attn_weights = attn_weights / math.sqrt(self.head_dim) + invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), + ("Attention weights should be of size " + f"{(bsz, self.num_heads, q_len, kv_seq_len)}," + "but is {attn_weights.size()}")) - invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len), - ("Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}," - "but is {attn_weights.size()}")) + if attention_mask is not None: + invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), + (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}" + f" but is {attention_mask.size()}")) - if attention_mask is not None: - invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}" - f" but is {attention_mask.size()}")) + attn_weights = attn_weights + attention_mask - attn_weights = attn_weights + attention_mask + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, + training=self.training) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, - training=self.training) - - if q_len != 1: attn_output = torch.matmul(attn_weights, value) - else: - import linear_q4_0 - attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights, - value_states.transpose(-1, -2)) invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), "`attn_output` should be of size "