refactor chatglm2, internlm, stablelm and qwen (#12604)

Yishuo Wang, 2024-12-24 18:18:00 +08:00, committed by GitHub
parent 073f936c37
commit 4135b895b3
4 changed files with 53 additions and 279 deletions
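
All four diffs below collapse the same hand-rolled dispatch (xe_addons.sdp / sdp_fp8 / sdp_causal, a CPU fallback, and a manual matmul-softmax path) into one call to scaled_dot_product_attention from ipex_llm.transformers.models.common. The helper's implementation is not part of this commit; the sketch below is a hypothetical illustration only, with the five-argument signature (query, key, value, mask, is-causal flag) inferred from the call sites in the diff and the body reduced to plain PyTorch SDPA. The real helper is expected to also cover the concerns visible in the deleted code (XPU xe_addons kernels, FP8/quantized KV caches).

# Hypothetical sketch only: not the ipex_llm implementation.
import torch

def scaled_dot_product_attention(query, key, value, mask, is_causal):
    # query/key/value: [batch, heads, seq, head_dim]; mask: additive/boolean mask or None
    n_rep = query.size(1) // key.size(1)
    if n_rep > 1:  # expand grouped KV heads so plain SDPA accepts them
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
    if is_causal:
        # full-prompt case (q_len == kv_seq_len): let SDPA build the causal mask itself
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=True)
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=mask)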

ipex_llm/transformers/models/chatglm2.py

@@ -18,17 +18,16 @@
 #
 import os
-import math
 import torch
 from typing import Optional, Tuple
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from ipex_llm.utils.common.log4Error import invalidInputError
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import update_past_key_value
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
-    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache

 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@@ -310,50 +309,10 @@ def chatglm2_attention_forward(
                       value_states.permute(2, 0, 1, 3)) if use_cache else None

     # IPEX-LLM OPT: sdp
-    attn_weights = None
-    if use_sdp(q_len, kv_seq_len, head_dim, query_states):
-        import xe_addons
-        if use_compresskv and attention_mask is not None:
-            attention_mask = None
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
-                                                   attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
-                                               attention_mask)
-    elif query_states.device.type == "cpu":
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, n_head // n_kv_head)
-        value_states = repeat_kv(value_states, n_head // n_kv_head)
-        if q_len == kv_seq_len:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, is_causal=True
-            )
-        else:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attention_mask
-            )
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, n_head // n_kv_head)
-        value_states = repeat_kv(value_states, n_head // n_kv_head)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )

     # context_layer's shape: [bsz, n_head, seq_len, head_dim] -> [seq_len, bsz, n_head * head_dim]
     attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(q_len, bsz, n_head * head_dim)
@@ -541,29 +500,10 @@ def codegeex_attention_forward(
     # =================
     # Output. [sq, b, h]
     # =================
-    context_layer = None
-    if use_sdp(q_len, kv_seq_len, head_dim, query_layer):
-        import xe_addons
-        context_layer = xe_addons.sdp(query_layer, key_layer, value_layer, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_layer, self.training):
-        import xe_addons
-        context_layer = xe_addons.sdp_causal(query_layer, key_layer, value_layer, attention_mask)
-    else:
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_layer = repeat_kv(key_layer, n_head // n_kv_head)
-        value_layer = repeat_kv(value_layer, n_head // n_kv_head)
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             is_causal=True)
-        else:
-            if attention_mask is not None:
-                attention_mask = ~attention_mask
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             attention_mask)
+    context_layer = scaled_dot_product_attention(
+        query_layer, key_layer, value_layer,
+        attention_mask, q_len == kv_seq_len
+    )

     context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(q_len,
                                                                         bsz,
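
Both hunks in this file leave the surrounding layout code untouched: the attention call returns output as [bsz, n_head, q_len, head_dim], and the existing permute/view converts it to ChatGLM's [seq_len, bsz, hidden] layout. A toy shape check (dimensions chosen arbitrarily for illustration):

# Illustration of the unchanged reshape that follows the attention call in chatglm2_attention_forward.
import torch

bsz, n_head, q_len, head_dim = 2, 8, 5, 64
attn_output = torch.randn(bsz, n_head, q_len, head_dim)
# [bsz, n_head, q_len, head_dim] -> [q_len, bsz, n_head * head_dim]
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(q_len, bsz, n_head * head_dim)
assert attn_output.shape == (q_len, bsz, n_head * head_dim)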

ipex_llm/transformers/models/internlm.py

@@ -36,18 +36,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch InternLM model."""
-import math
 from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
-from torch import nn
 from ipex_llm.utils.common.log4Error import invalidInputError
-from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import update_past_key_value
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from einops import rearrange
@@ -98,35 +96,10 @@ def internlm_attention_forward(
     # IPEX-LLM OPT: sdp
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )

     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -207,38 +180,10 @@ def internlm2_attention_forward(
     # IPEX-LLM OPT: sdp
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -409,38 +354,11 @@ def internlm_xcomposser2_attention_forward(
     past_key_value = (key_states, value_states) if use_cache else None

     # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
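
The deleted fallbacks in internlm2_attention_forward and internlm_xcomposser2_attention_forward expanded grouped KV heads with repeat_kv(key_states, self.num_key_value_groups) before the matmul, a concern the shared helper now has to absorb. For reference, a sketch of the usual expand-and-reshape formulation of repeat_kv (the pattern used in Hugging Face-style attention code; not copied from this repository):

# Reference sketch of the repeat_kv pattern the deleted fallback relied on.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [batch, num_kv_heads, seq_len, head_dim] -> [batch, num_kv_heads * n_rep, seq_len, head_dim]
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)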

ipex_llm/transformers/models/qwen.py

@@ -22,19 +22,19 @@
 # LICENSE file in the root directory of this source tree.
 #
-import math
 from typing import Optional, Tuple, Union, Callable, List
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from transformers.utils import logging
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import update_past_key_value, should_use_fuse_rope
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
-from ipex_llm.utils.common import invalidInputError, invalidOperationError
+from ipex_llm.transformers.models.utils import use_flash_attention
+from ipex_llm.utils.common import invalidInputError
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -118,20 +118,13 @@ def qwen_attention_forward(
     # IPEX-LLM OPT: sdp
     attn_weights = None
-    if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+    if use_flash_attention(query_states, key_states, attention_mask):
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
                                                      key_states.to(dtype=torch.float16),
                                                      value_states.to(dtype=torch.float16),
                                                      is_causal=True).to(hidden_states.dtype)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None)
     else:
-        if q_len > 1:
+        if q_len > 1 and q_len != kv_seq_len:
             causal_mask = torch.tril(
                 torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device)
             ).view(1, 1, kv_seq_len, kv_seq_len)
@@ -146,29 +139,10 @@ def qwen_attention_forward(
         else:
             attention_mask = None
-        if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-            import xe_addons
-            if use_quantize_kv:
-                attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                                attention_mask)
-            else:
-                attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            if use_quantize_kv:
-                key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                                query_states.dtype)
-            attn_weights = torch.matmul(query_states,
-                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-            if attention_mask is not None:
-                attn_weights = attn_weights + attention_mask
-            if self.softmax_in_fp32:
-                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                           dtype=torch.float32).to(
-                    value_states.dtype)
-            else:
-                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-            attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = scaled_dot_product_attention(
+            query_states, key_states, value_states,
+            attention_mask, q_len == kv_seq_len
+        )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -247,20 +221,14 @@ def qwen_attention_forward_registered(
     # IPEX-LLM OPT: sdp
     attn_weights = None
-    if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+    if use_flash_attention(query_states, key_states, attention_mask):
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
                                                      key_states.to(dtype=torch.float16),
                                                      value_states.to(dtype=torch.float16),
                                                      is_causal=True).to(hidden_states.dtype)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None)
     else:
-        if q_len > 1:
+        if q_len > 1 and q_len != kv_seq_len:
             causal_mask = registered_causal_mask[
                 :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
             ]
@@ -272,29 +240,10 @@ def qwen_attention_forward_registered(
         else:
             attention_mask = None
-        if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-            import xe_addons
-            if use_quantize_kv:
-                attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                                attention_mask)
-            else:
-                attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            if use_quantize_kv:
-                key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                                query_states.dtype)
-            attn_weights = torch.matmul(query_states,
-                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-            if attention_mask is not None:
-                attn_weights = attn_weights + attention_mask
-            if self.softmax_in_fp32:
-                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                           dtype=torch.float32).to(
-                    value_states.dtype)
-            else:
-                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-            attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = scaled_dot_product_attention(
+            query_states, key_states, value_states,
+            attention_mask, q_len == kv_seq_len
+        )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.view(bsz, q_len, self.hidden_size)
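
In both Qwen forwards, the explicit causal mask is now only built when q_len > 1 and q_len != kv_seq_len; in the full-prefill case the helper already receives q_len == kv_seq_len as its causal flag, which is presumably why the tril mask is skipped there. A small float32 check (toy shapes, plain PyTorch SDPA standing in for the helper) that an explicit lower-triangular boolean mask and is_causal=True agree:

# Illustration: for q_len == kv_seq_len, an explicit tril mask and is_causal=True match.
import torch

q = torch.randn(1, 2, 6, 16)
k = torch.randn(1, 2, 6, 16)
v = torch.randn(1, 2, 6, 16)
causal_mask = torch.tril(torch.ones(6, 6, dtype=torch.bool)).view(1, 1, 6, 6)

out_mask = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_causal = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_mask, out_causal, atol=1e-5)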

ipex_llm/transformers/models/stablelm.py

@@ -37,18 +37,16 @@
 # limitations under the License.
 #
-import math
 from typing import Optional, Tuple, List
 import torch
 from transformers.cache_utils import Cache
-from transformers.models.stablelm.modeling_stablelm import repeat_kv
 from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel
-from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
@@ -143,41 +141,10 @@ def stablelm_attention_forward(
     # IPEX-LLM OPT: sdp
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_weights = self.attention_dropout(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )

     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
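
Every deleted else-branch in this commit implemented the same manual computation: scores = QK^T / sqrt(head_dim), an optional additive mask, a softmax in fp32, then a weighted sum over V. A self-contained check (toy shapes, CPU, float32) that this manual path matches torch.nn.functional.scaled_dot_product_attention, the computation the unified call stands in for once backend dispatch is set aside:

# The removed manual fallback computes the same result as PyTorch SDPA.
import math
import torch

bsz, n_head, q_len, kv_len, head_dim = 1, 4, 3, 7, 32
q = torch.randn(bsz, n_head, q_len, head_dim)
k = torch.randn(bsz, n_head, kv_len, head_dim)
v = torch.randn(bsz, n_head, kv_len, head_dim)
mask = torch.zeros(bsz, 1, q_len, kv_len)
mask[..., -2:] = float("-inf")  # additive mask: block the last two KV positions

attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
manual = torch.matmul(attn_weights, v)

reference = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(manual, reference, atol=1e-5)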