support running internlm-xcomposer2 on cpu (#11111)
parent e0f401d97d
commit 0e53f20edb
2 changed files with 145 additions and 2 deletions
@@ -725,6 +725,10 @@ def _optimize_pre(model):
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
         model.apply(merge_qkv)
+    # for internlm-xcomposer2-vl
+    if model.config.model_type == "internlmxcomposer2":
+        from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
+        model.apply(pre_process_attn_and_mlp)
 
     return model
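The new hook is dispatched through torch.nn.Module.apply, which walks every submodule recursively, so pre_process_attn_and_mlp sees each InternLM2Attention and InternLM2MLP instance exactly once. A toy illustration of that traversal (TinyAttention and TinyModel are hypothetical stand-ins, not the real classes):

import torch

class TinyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wqkv = torch.nn.Linear(8, 24, bias=False)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = TinyAttention()

def visit(module: torch.nn.Module):
    # Module.apply calls this on every submodule first, then on the root.
    print(module.__class__.__name__)

TinyModel().apply(visit)  # prints: Linear, TinyAttention, TinyModel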
@@ -1217,6 +1221,13 @@ def _optimize_post(model, lightweight_bmm=False):
             module.InternLMRMSNorm,
             llama_rms_norm_forward
         )
+    elif model.config.model_type == "internlmxcomposer2":
+        modeling_module_name = model.model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward
+        convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
+        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
+        convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
             # for Qwen-VL-Chat
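convert_forward swaps the forward method of every instance of the given class for the CPU-friendly implementations added in internlm.py below. A simplified sketch of the idea, not the actual ipex_llm helper:

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
    # Bind new_forward to each matching instance so model(...) uses it from now on.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)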
@@ -47,7 +47,7 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from einops import rearrange
 import os
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
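KV_CACHE_ALLOC_BLOCK_LENGTH in the context above is read from the environment at import time, so the KV cache allocation block size can be tuned without code changes, for example:

import os
os.environ["KV_CACHE_ALLOC_BLOCK_LENGTH"] = "512"  # set before ipex_llm is imported; 256 is the default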
@@ -205,7 +205,6 @@ def internlm2_attention_forward(
     bsz, q_len, _ = hidden_states.size()
 
     qkv_states = self.wqkv(hidden_states)
-    from einops import rearrange
     qkv_states = rearrange(
         qkv_states,
         "b q (h gs d) -> b q h gs d",
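The einops pattern above unpacks the fused wqkv projection into per-key-value-head groups: the first num_key_value_groups slots of each group are query heads and the last two are the key and value heads. A toy example with made-up sizes:

import torch
from einops import rearrange

b, q, kv_heads, groups, d = 1, 3, 2, 4, 8  # hypothetical sizes
qkv = torch.randn(b, q, kv_heads * (groups + 2) * d)
qkv = rearrange(qkv, "b q (h gs d) -> b q h gs d", gs=groups + 2, d=d)
queries = rearrange(qkv[..., :groups, :], "b q h gs d -> b q (h gs) d")
keys, values = qkv[..., -2, :], qkv[..., -1, :]
print(queries.shape, keys.shape, values.shape)
# torch.Size([1, 3, 8, 8]) torch.Size([1, 3, 2, 8]) torch.Size([1, 3, 2, 8])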
@@ -291,3 +290,136 @@ def internlm2_attention_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def pre_process_attn_and_mlp(module: torch.nn.Module):
+    if module.__class__.__name__ == "InternLM2Attention":
+        module.wqkv_lora_scaling = module.wqkv.lora_scaling
+        module.wqkv_Plora_A = module.wqkv.Plora_A
+        module.wqkv_Plora_B = module.wqkv.Plora_B
+        del module.wqkv.Plora_A
+        del module.wqkv.Plora_B
+
+        module.wo_lora_scaling = module.wo.lora_scaling
+        module.wo_Plora_A = module.wo.Plora_A
+        module.wo_Plora_B = module.wo.Plora_B
+        del module.wo.Plora_A
+        del module.wo.Plora_B
+
+    elif module.__class__.__name__ == "InternLM2MLP":
+        module.w1_lora_scaling = module.w1.lora_scaling
+        module.w1_Plora_A = module.w1.Plora_A
+        module.w1_Plora_B = module.w1.Plora_B
+        del module.w1.Plora_A
+        del module.w1.Plora_B
+
+        module.w2_lora_scaling = module.w2.lora_scaling
+        module.w2_Plora_A = module.w2.Plora_A
+        module.w2_Plora_B = module.w2.Plora_B
+        del module.w2.Plora_A
+        del module.w2.Plora_B
+
+        module.w3_lora_scaling = module.w3.lora_scaling
+        module.w3_Plora_A = module.w3.Plora_A
+        module.w3_Plora_B = module.w3.Plora_B
+        del module.w3.Plora_A
+        del module.w3.Plora_B
+
+
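pre_process_attn_and_mlp hoists each projection's partial-LoRA pieces (Plora_A, Plora_B, lora_scaling) onto the parent attention or MLP module and removes them from the linear itself, so the base projection behaves like a plain linear while the LoRA branch is applied explicitly through add_lora in the new forwards below. A minimal stand-in for the kind of PLoRA linear this expects (the attribute names follow the code above; the rest is hypothetical):

import torch

class ToyPLoRALinear(torch.nn.Linear):
    def __init__(self, in_features, out_features, rank=4, lora_scaling=1.0):
        super().__init__(in_features, out_features, bias=False)
        self.lora_scaling = lora_scaling
        self.Plora_A = torch.nn.Linear(in_features, rank, bias=False)
        self.Plora_B = torch.nn.Linear(rank, out_features, bias=False)

After the hook runs, wqkv, wo, w1, w2 and w3 carry only the base weights; their LoRA pieces live on the parent module as wqkv_Plora_A, wo_Plora_B, and so on.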
+def add_lora(x: torch.Tensor, result: torch.Tensor,
+             im_mask: torch.Tensor = None, lora_scaling: float = 0,
+             Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
+    if im_mask is not None and torch.sum(im_mask) > 0:
+        part_x = x[im_mask]
+        result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
+    return result
+
+
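add_lora applies the low-rank correction only at positions flagged by im_mask (the image-token positions); text positions keep the plain projection. A toy run with made-up sizes, assuming the add_lora definition above is in scope:

import torch

base = torch.nn.Linear(4, 4, bias=False)
Plora_A = torch.nn.Linear(4, 2, bias=False)
Plora_B = torch.nn.Linear(2, 4, bias=False)

with torch.no_grad():
    x = torch.randn(3, 4)                         # 3 token positions
    im_mask = torch.tensor([False, True, False])  # only position 1 is an image token
    out = add_lora(x, base(x), im_mask, lora_scaling=1.0,
                   Plora_A=Plora_A, Plora_B=Plora_B)
    print(torch.allclose(out[0], base(x)[0]))     # True: text position untouched
    print(torch.allclose(out[1], base(x)[1]))     # False: LoRA delta added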
+def internlm_xcomposser2_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    im_mask: Optional[Tuple[torch.Tensor]] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv_states = self.wqkv(hidden_states)
+    qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling,
+                          self.wqkv_Plora_A, self.wqkv_Plora_B)
+
+    qkv_states = rearrange(
+        qkv_states,
+        'b q (h gs d) -> b q h gs d',
+        gs=2 + self.num_key_value_groups,
+        d=self.head_dim,
+    )
+
+    query_states = qkv_states[..., :self.num_key_value_groups, :]
+    query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+    key_states = qkv_states[..., -2, :]
+    value_states = qkv_states[..., -1, :]
+
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids, "internlm")
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_weights = torch.matmul(query_states, key_states.transpose(
+        2, 3)) / math.sqrt(self.head_dim)
+
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(
+        attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output_2 = self.wo(attn_output)
+
+    attn_output = add_lora(attn_output, attn_output_2, im_mask, self.wo_lora_scaling,
+                           self.wo_Plora_A, self.wo_Plora_B)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
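Everything in this forward is plain PyTorch (matmul, softmax, the generic apply_rotary_pos_emb path), so nothing here is XPU-specific and it can run on CPU, which is the point of this commit. repeat_kv is not part of this diff; the standard grouped-query-attention helper it is assumed to match looks like this:

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)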
+def internlm_xcomposser2_mlp_forward(
+    self,
+    x: torch.Tensor,
+    im_mask: Optional[Tuple[torch.Tensor]] = None,
+):
+    w1 = self.w1(x)
+    w1 = add_lora(x, w1, im_mask, self.w1_lora_scaling, self.w1_Plora_A, self.w1_Plora_B)
+    w3 = self.w3(x)
+    w3 = add_lora(x, w3, im_mask, self.w3_lora_scaling, self.w3_Plora_A, self.w3_Plora_B)
+    x = self.act_fn(w1) * w3
+    w2 = self.w2(x)
+    w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B)
+    return w2
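The MLP keeps the usual SwiGLU shape, w2(act_fn(w1(x)) * w3(x)), with a partial-LoRA correction on each of the three projections. A tiny check with made-up sizes (SiLU is assumed for act_fn, and add_lora plus internlm_xcomposser2_mlp_forward from above are assumed to be in scope) that it reduces to plain SwiGLU when no image tokens are present:

import torch
from types import SimpleNamespace

d, hidden = 4, 6
w1 = torch.nn.Linear(d, hidden, bias=False)
w2 = torch.nn.Linear(hidden, d, bias=False)
w3 = torch.nn.Linear(d, hidden, bias=False)

def make_plora(in_f, out_f, rank=2):
    return torch.nn.Linear(in_f, rank, bias=False), torch.nn.Linear(rank, out_f, bias=False)

a1, b1 = make_plora(d, hidden)
a2, b2 = make_plora(hidden, d)
a3, b3 = make_plora(d, hidden)
mlp = SimpleNamespace(w1=w1, w2=w2, w3=w3, act_fn=torch.nn.SiLU(),
                      w1_lora_scaling=1.0, w1_Plora_A=a1, w1_Plora_B=b1,
                      w2_lora_scaling=1.0, w2_Plora_A=a2, w2_Plora_B=b2,
                      w3_lora_scaling=1.0, w3_Plora_A=a3, w3_Plora_B=b3)

with torch.no_grad():
    x = torch.randn(1, 5, d)
    im_mask = torch.zeros(1, 5, dtype=torch.bool)   # text-only prompt
    out = internlm_xcomposser2_mlp_forward(mlp, x, im_mask)
    ref = w2(torch.nn.SiLU()(w1(x)) * w3(x))        # plain SwiGLU
    print(torch.allclose(out, ref))                 # True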