support running internlm-xcomposer2 on cpu (#11111)

This commit is contained in:
Yishuo Wang 2024-05-23 16:36:09 +08:00 committed by GitHub
parent e0f401d97d
commit 0e53f20edb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 145 additions and 2 deletions

View file

@ -725,6 +725,10 @@ def _optimize_pre(model):
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
from ipex_llm.transformers.models.stablelm import merge_qkv
model.apply(merge_qkv)
# for internlm-xcomposer2-vl
if model.config.model_type == "internlmxcomposer2":
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
model.apply(pre_process_attn_and_mlp)
return model
@ -1217,6 +1221,13 @@ def _optimize_post(model, lightweight_bmm=False):
module.InternLMRMSNorm,
llama_rms_norm_forward
)
elif model.config.model_type == "internlmxcomposer2":
modeling_module_name = model.model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward
convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
elif model.config.model_type == "qwen":
if hasattr(model.config, "visual"):
# for Qwen-VL-Chat

View file

@ -47,7 +47,7 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
append_kv_cache, is_enough_kv_cache_room_4_31
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from einops import rearrange
import os
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@ -205,7 +205,6 @@ def internlm2_attention_forward(
bsz, q_len, _ = hidden_states.size()
qkv_states = self.wqkv(hidden_states)
from einops import rearrange
qkv_states = rearrange(
qkv_states,
"b q (h gs d) -> b q h gs d",
@ -291,3 +290,136 @@ def internlm2_attention_forward(
attn_weights = None
return attn_output, attn_weights, past_key_value
def pre_process_attn_and_mlp(module: torch.nn.Module):
if module.__class__.__name__ == "InternLM2Attention":
module.wqkv_lora_scaling = module.wqkv.lora_scaling
module.wqkv_Plora_A = module.wqkv.Plora_A
module.wqkv_Plora_B = module.wqkv.Plora_B
del module.wqkv.Plora_A
del module.wqkv.Plora_B
module.wo_lora_scaling = module.wo.lora_scaling
module.wo_Plora_A = module.wo.Plora_A
module.wo_Plora_B = module.wo.Plora_B
del module.wo.Plora_A
del module.wo.Plora_B
elif module.__class__.__name__ == "InternLM2MLP":
module.w1_lora_scaling = module.w1.lora_scaling
module.w1_Plora_A = module.w1.Plora_A
module.w1_Plora_B = module.w1.Plora_B
del module.w1.Plora_A
del module.w1.Plora_B
module.w2_lora_scaling = module.w2.lora_scaling
module.w2_Plora_A = module.w2.Plora_A
module.w2_Plora_B = module.w2.Plora_B
del module.w2.Plora_A
del module.w2.Plora_B
module.w3_lora_scaling = module.w3.lora_scaling
module.w3_Plora_A = module.w3.Plora_A
module.w3_Plora_B = module.w3.Plora_B
del module.w3.Plora_A
del module.w3.Plora_B
def add_lora(x: torch.Tensor, result: torch.Tensor,
im_mask: torch.Tensor = None, lora_scaling: float = 0,
Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
if im_mask is not None and torch.sum(im_mask) > 0:
part_x = x[im_mask]
result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
return result
def internlm_xcomposser2_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
im_mask: Optional[Tuple[torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
qkv_states = self.wqkv(hidden_states)
qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
self.wqkv_Plora_A, self.wqkv_Plora_B)
qkv_states = rearrange(
qkv_states,
'b q (h gs d) -> b q h gs d',
gs=2 + self.num_key_value_groups,
d=self.head_dim,
)
query_states = qkv_states[..., :self.num_key_value_groups, :]
query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids, "internlm")
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output_2 = self.wo(attn_output)
attn_output = add_lora(attn_output, attn_output_2, im_mask, self.wo_lora_scaling,
self.wo_Plora_A, self.wo_Plora_B)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def internlm_xcomposser2_mlp_forward(
self,
x: torch.Tensor,
im_mask: Optional[Tuple[torch.Tensor]] = None,
):
w1 = self.w1(x)
w1 = add_lora(x, w1, im_mask, self.w1_lora_scaling, self.w1_Plora_A, self.w1_Plora_B)
w3 = self.w3(x)
w3 = add_lora(x, w3, im_mask, self.w3_lora_scaling, self.w3_Plora_A, self.w3_Plora_B)
x = self.act_fn(w1) * w3
w2 = self.w2(x)
w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
return w2