refactor internlm and internlm2 (#11274)
This commit is contained in:
parent
fac49f15e3
commit
10e480ee96
2 changed files with 124 additions and 160 deletions
|
|
@ -719,6 +719,10 @@ def _optimize_pre(model):
|
||||||
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
|
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
|
||||||
from ipex_llm.transformers.models.stablelm import merge_qkv
|
from ipex_llm.transformers.models.stablelm import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
|
# for internlm
|
||||||
|
if model.config.model_type == "internlm":
|
||||||
|
from ipex_llm.transformers.models.internlm import merge_qkv
|
||||||
|
model.apply(merge_qkv)
|
||||||
# for internlm-xcomposer2-vl
|
# for internlm-xcomposer2-vl
|
||||||
if model.config.model_type == "internlmxcomposer2":
|
if model.config.model_type == "internlmxcomposer2":
|
||||||
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
||||||
|
|
@ -1167,27 +1171,14 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from ipex_llm.transformers.models.internlm import internlm_attention_forward
|
from ipex_llm.transformers.models.internlm import internlm_attention_forward
|
||||||
|
convert_forward(model, module.InternLMAttention, internlm_attention_forward)
|
||||||
|
convert_forward(model, module.InternLMRMSNorm, llama_rms_norm_forward)
|
||||||
|
elif model.config.model_type == "internlm2":
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
from ipex_llm.transformers.models.internlm import internlm2_attention_forward
|
from ipex_llm.transformers.models.internlm import internlm2_attention_forward
|
||||||
try:
|
convert_forward(model, module.InternLM2Attention, internlm2_attention_forward)
|
||||||
convert_forward(model,
|
convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
|
||||||
module.InternLM2Attention,
|
|
||||||
internlm2_attention_forward
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
convert_forward(model,
|
|
||||||
module.InternLMAttention,
|
|
||||||
internlm_attention_forward
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
convert_forward(model,
|
|
||||||
module.InternLM2RMSNorm,
|
|
||||||
llama_rms_norm_forward
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
convert_forward(model,
|
|
||||||
module.InternLMRMSNorm,
|
|
||||||
llama_rms_norm_forward
|
|
||||||
)
|
|
||||||
elif model.config.model_type == "internlmxcomposer2":
|
elif model.config.model_type == "internlmxcomposer2":
|
||||||
modeling_module_name = model.model.__class__.__module__
|
modeling_module_name = model.model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
|
|
||||||
|
|
@ -42,20 +42,35 @@ from typing import Optional, Tuple, List
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from ipex_llm.utils.common import invalidInputError
|
|
||||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
|
||||||
append_kv_cache, is_enough_kv_cache_room_4_31
|
|
||||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
|
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
|
||||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
|
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
|
||||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||||
from ipex_llm.transformers.models.utils import update_past_key_value
|
from ipex_llm.transformers.models.utils import update_past_key_value
|
||||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
|
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
def merge_qkv(module: torch.nn.Module):
|
||||||
|
if module.__class__.__name__ == "InternLMAttention":
|
||||||
|
new_weight = torch.cat([
|
||||||
|
module.q_proj.weight.data,
|
||||||
|
module.k_proj.weight.data,
|
||||||
|
module.v_proj.weight.data,
|
||||||
|
], dim=0)
|
||||||
|
new_bias = torch.cat([
|
||||||
|
module.q_proj.bias.data,
|
||||||
|
module.k_proj.bias.data,
|
||||||
|
module.v_proj.bias.data,
|
||||||
|
], dim=-1)
|
||||||
|
|
||||||
|
qkv_proj = torch.nn.Linear(0, 0, bias=True)
|
||||||
|
qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
|
||||||
|
qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
|
||||||
|
qkv_proj.in_features = new_weight.size(1)
|
||||||
|
qkv_proj.out_features = new_weight.size(0)
|
||||||
|
module.qkv_proj = qkv_proj
|
||||||
|
|
||||||
|
del module.q_proj, module.k_proj, module.v_proj
|
||||||
|
|
||||||
|
|
||||||
def internlm_attention_forward(
|
def internlm_attention_forward(
|
||||||
|
|
@ -68,109 +83,69 @@ def internlm_attention_forward(
|
||||||
use_cache: bool=False,
|
use_cache: bool=False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
device = hidden_states.device
|
|
||||||
query_states = self.q_proj(hidden_states) \
|
qkv = self.qkv_proj(hidden_states)
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim) \
|
qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
|
||||||
.transpose(1, 2)
|
qkv = qkv.transpose(1, 2)
|
||||||
key_states = self.k_proj(hidden_states) \
|
query_states, key_states, value_states = qkv.split([self.num_heads,
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim) \
|
self.num_heads,
|
||||||
.transpose(1, 2)
|
self.num_heads], dim=1)
|
||||||
value_states = self.v_proj(hidden_states) \
|
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim) \
|
|
||||||
.transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
enough_kv_room = True
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=kv_seq_len)
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: fuse rope
|
||||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||||
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
import xe_addons
|
||||||
key_states,
|
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||||
position_ids,
|
query_states, key_states)
|
||||||
"internlm")
|
|
||||||
else:
|
else:
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states,
|
query_states, key_states, cos, sin, position_ids, "internlm"
|
||||||
key_states,
|
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
position_ids,
|
|
||||||
"internlm")
|
|
||||||
# [bsz, nh, t, hd]
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
cache_k = past_key_value[0]
|
|
||||||
cache_v = past_key_value[1]
|
|
||||||
if not enough_kv_room:
|
|
||||||
# allocate new
|
|
||||||
new_cache_k, new_cache_v = extend_kv_cache(
|
|
||||||
bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
cache_k.size(2),
|
|
||||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
|
||||||
dtype=cache_k.dtype,
|
|
||||||
device=device
|
|
||||||
)
|
|
||||||
new_cache_k[:] = cache_k
|
|
||||||
new_cache_v[:] = cache_v
|
|
||||||
cache_k = new_cache_k
|
|
||||||
cache_v = new_cache_v
|
|
||||||
|
|
||||||
key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
|
|
||||||
|
|
||||||
elif use_cache:
|
|
||||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
|
||||||
new_key_states, new_value_states = init_kv_cache(
|
|
||||||
bsz,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
kv_seq_len,
|
|
||||||
max_cache_length,
|
|
||||||
dtype=key_states.dtype,
|
|
||||||
device=device
|
|
||||||
)
|
)
|
||||||
new_key_states[:] = key_states
|
|
||||||
new_value_states[:] = value_states
|
|
||||||
key_states = new_key_states
|
|
||||||
value_states = new_value_states
|
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: kv cache and quantzie kv cache
|
||||||
|
use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
|
||||||
|
key_states, value_states = update_past_key_value(
|
||||||
|
past_key_value, key_states, value_states,
|
||||||
|
kv_seq_len, use_quantize_kv, hidden_states.device
|
||||||
|
)
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states,
|
# IPEX-LLM OPT: sdp
|
||||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
attn_weights = None
|
||||||
|
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
|
import xe_addons
|
||||||
|
if use_quantize_kv:
|
||||||
|
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
|
||||||
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
|
import xe_addons
|
||||||
|
if use_quantize_kv:
|
||||||
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
|
else:
|
||||||
|
if use_quantize_kv:
|
||||||
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
query_states.dtype)
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
attn_weights = torch.matmul(query_states,
|
||||||
invalidInputError(
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
False,
|
|
||||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
attn_weights = attn_weights + attention_mask
|
||||||
invalidInputError(
|
|
||||||
False,
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = nn.functional.softmax(attn_weights,
|
attn_weights = nn.functional.softmax(attn_weights,
|
||||||
dim=-1, dtype=torch.float32).to(query_states.dtype)
|
dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
||||||
invalidInputError(
|
|
||||||
False,
|
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, "
|
|
||||||
f"but is {attn_output.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
attn_output = attn_output.transpose(1, 2)
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
@ -229,62 +204,60 @@ def internlm2_attention_forward(
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: fuse rope
|
||||||
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
if should_use_fuse_rope(hidden_states, position_ids, self.training):
|
||||||
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
import xe_addons
|
||||||
key_states,
|
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
|
||||||
position_ids,
|
query_states, key_states)
|
||||||
"internlm")
|
|
||||||
else:
|
else:
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
# query_states, key_states = apply_rotary_pos_emb(query_states,
|
|
||||||
# key_states, cos, sin, position_ids)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states,
|
query_states, key_states, cos, sin, position_ids, "internlm"
|
||||||
key_states,
|
)
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
position_ids,
|
|
||||||
"internlm")
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
|
# IPEX-LLM OPT: kv cache and quantzie kv cache
|
||||||
|
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
|
||||||
|
key_states, value_states = update_past_key_value(
|
||||||
|
past_key_value, key_states, value_states,
|
||||||
|
kv_seq_len, use_quantize_kv, hidden_states.device
|
||||||
|
)
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
# IPEX-LLM OPT: sdp
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
attn_weights = None
|
||||||
|
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
|
||||||
|
import xe_addons
|
||||||
|
if use_quantize_kv:
|
||||||
|
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
|
||||||
|
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
|
||||||
|
import xe_addons
|
||||||
|
if use_quantize_kv:
|
||||||
|
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
|
else:
|
||||||
|
if use_quantize_kv:
|
||||||
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
|
query_states.dtype)
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
attn_weights = torch.matmul(query_states,
|
||||||
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
if attention_mask is not None:
|
||||||
invalidInputError(
|
attn_weights = attn_weights + attention_mask
|
||||||
False,
|
|
||||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
# upcast attention to fp32
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
attn_weights = nn.functional.softmax(attn_weights,
|
||||||
invalidInputError(
|
dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
False,
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
|
||||||
f"but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights,
|
|
||||||
dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
||||||
invalidInputError(
|
|
||||||
False,
|
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, "
|
|
||||||
f"but is {attn_output.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue