support running internlm-xcomposer2 on cpu (#11111)
This commit is contained in:
parent
e0f401d97d
commit
0e53f20edb
2 changed files with 145 additions and 2 deletions
|
|
@ -725,6 +725,10 @@ def _optimize_pre(model):
|
|||
# For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
|
||||
from ipex_llm.transformers.models.stablelm import merge_qkv
|
||||
model.apply(merge_qkv)
|
||||
# for internlm-xcomposer2-vl
|
||||
if model.config.model_type == "internlmxcomposer2":
|
||||
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
||||
model.apply(pre_process_attn_and_mlp)
|
||||
|
||||
return model
|
||||
|
||||
|
|
@ -1217,6 +1221,13 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module.InternLMRMSNorm,
|
||||
llama_rms_norm_forward
|
||||
)
|
||||
elif model.config.model_type == "internlmxcomposer2":
|
||||
modeling_module_name = model.model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward
|
||||
convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
|
||||
from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
|
||||
convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
|
||||
elif model.config.model_type == "qwen":
|
||||
if hasattr(model.config, "visual"):
|
||||
# for Qwen-VL-Chat
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
|||
append_kv_cache, is_enough_kv_cache_room_4_31
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
|
||||
|
||||
from einops import rearrange
|
||||
import os
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||
|
|
@ -205,7 +205,6 @@ def internlm2_attention_forward(
|
|||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states = self.wqkv(hidden_states)
|
||||
from einops import rearrange
|
||||
qkv_states = rearrange(
|
||||
qkv_states,
|
||||
"b q (h gs d) -> b q h gs d",
|
||||
|
|
@ -291,3 +290,136 @@ def internlm2_attention_forward(
|
|||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def pre_process_attn_and_mlp(module: torch.nn.Module):
|
||||
if module.__class__.__name__ == "InternLM2Attention":
|
||||
module.wqkv_lora_scaling = module.wqkv.lora_scaling
|
||||
module.wqkv_Plora_A = module.wqkv.Plora_A
|
||||
module.wqkv_Plora_B = module.wqkv.Plora_B
|
||||
del module.wqkv.Plora_A
|
||||
del module.wqkv.Plora_B
|
||||
|
||||
module.wo_lora_scaling = module.wo.lora_scaling
|
||||
module.wo_Plora_A = module.wo.Plora_A
|
||||
module.wo_Plora_B = module.wo.Plora_B
|
||||
del module.wo.Plora_A
|
||||
del module.wo.Plora_B
|
||||
|
||||
elif module.__class__.__name__ == "InternLM2MLP":
|
||||
module.w1_lora_scaling = module.w1.lora_scaling
|
||||
module.w1_Plora_A = module.w1.Plora_A
|
||||
module.w1_Plora_B = module.w1.Plora_B
|
||||
del module.w1.Plora_A
|
||||
del module.w1.Plora_B
|
||||
|
||||
module.w2_lora_scaling = module.w2.lora_scaling
|
||||
module.w2_Plora_A = module.w2.Plora_A
|
||||
module.w2_Plora_B = module.w2.Plora_B
|
||||
del module.w2.Plora_A
|
||||
del module.w2.Plora_B
|
||||
|
||||
module.w3_lora_scaling = module.w3.lora_scaling
|
||||
module.w3_Plora_A = module.w3.Plora_A
|
||||
module.w3_Plora_B = module.w3.Plora_B
|
||||
del module.w3.Plora_A
|
||||
del module.w3.Plora_B
|
||||
|
||||
|
||||
def add_lora(x: torch.Tensor, result: torch.Tensor,
|
||||
im_mask: torch.Tensor = None, lora_scaling: float = 0,
|
||||
Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
|
||||
if im_mask is not None and torch.sum(im_mask) > 0:
|
||||
part_x = x[im_mask]
|
||||
result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
|
||||
return result
|
||||
|
||||
|
||||
def internlm_xcomposser2_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
im_mask: Optional[Tuple[torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv_states = self.wqkv(hidden_states)
|
||||
qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
|
||||
self.wqkv_Plora_A, self.wqkv_Plora_B)
|
||||
|
||||
qkv_states = rearrange(
|
||||
qkv_states,
|
||||
'b q (h gs d) -> b q h gs d',
|
||||
gs=2 + self.num_key_value_groups,
|
||||
d=self.head_dim,
|
||||
)
|
||||
|
||||
query_states = qkv_states[..., :self.num_key_value_groups, :]
|
||||
query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
|
||||
key_states = qkv_states[..., -2, :]
|
||||
value_states = qkv_states[..., -1, :]
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids, "internlm")
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(
|
||||
2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output_2 = self.wo(attn_output)
|
||||
|
||||
attn_output = add_lora(attn_output, attn_output_2, im_mask, self.wo_lora_scaling,
|
||||
self.wo_Plora_A, self.wo_Plora_B)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def internlm_xcomposser2_mlp_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
im_mask: Optional[Tuple[torch.Tensor]] = None,
|
||||
):
|
||||
w1 = self.w1(x)
|
||||
w1 = add_lora(x, w1, im_mask, self.w1_lora_scaling, self.w1_Plora_A, self.w1_Plora_B)
|
||||
w3 = self.w3(x)
|
||||
w3 = add_lora(x, w3, im_mask, self.w3_lora_scaling, self.w3_Plora_A, self.w3_Plora_B)
|
||||
x = self.act_fn(w1) * w3
|
||||
w2 = self.w2(x)
|
||||
w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
|
||||
return w2
|
||||
|
|
|
|||
Loading…
Reference in a new issue