From 0e53f20edbd20abb6b3c6c2e40a24e50f03ba928 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 23 May 2024 16:36:09 +0800 Subject: [PATCH] support running internlm-xcomposer2 on cpu (#11111) --- .../llm/src/ipex_llm/transformers/convert.py | 11 ++ .../ipex_llm/transformers/models/internlm.py | 136 +++++++++++++++++- 2 files changed, 145 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index a6b25a8f..6e5a33a7 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -725,6 +725,10 @@ def _optimize_pre(model): # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b from ipex_llm.transformers.models.stablelm import merge_qkv model.apply(merge_qkv) + # for internlm-xcomposer2-vl + if model.config.model_type == "internlmxcomposer2": + from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp + model.apply(pre_process_attn_and_mlp) return model @@ -1217,6 +1221,13 @@ def _optimize_post(model, lightweight_bmm=False): module.InternLMRMSNorm, llama_rms_norm_forward ) + elif model.config.model_type == "internlmxcomposer2": + modeling_module_name = model.model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward + convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward) + from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward + convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward) elif model.config.model_type == "qwen": if hasattr(model.config, "visual"): # for Qwen-VL-Chat diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py index f23f40b1..4b19faf1 100644 --- a/python/llm/src/ipex_llm/transformers/models/internlm.py +++ b/python/llm/src/ipex_llm/transformers/models/internlm.py @@ -47,7 +47,7 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \ append_kv_cache, is_enough_kv_cache_room_4_31 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu - +from einops import rearrange import os KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) @@ -205,7 +205,6 @@ def internlm2_attention_forward( bsz, q_len, _ = hidden_states.size() qkv_states = self.wqkv(hidden_states) - from einops import rearrange qkv_states = rearrange( qkv_states, "b q (h gs d) -> b q h gs d", @@ -291,3 +290,136 @@ def internlm2_attention_forward( attn_weights = None return attn_output, attn_weights, past_key_value + + +def pre_process_attn_and_mlp(module: torch.nn.Module): + if module.__class__.__name__ == "InternLM2Attention": + module.wqkv_lora_scaling = module.wqkv.lora_scaling + module.wqkv_Plora_A = module.wqkv.Plora_A + module.wqkv_Plora_B = module.wqkv.Plora_B + del module.wqkv.Plora_A + del module.wqkv.Plora_B + + module.wo_lora_scaling = module.wo.lora_scaling + module.wo_Plora_A = module.wo.Plora_A + module.wo_Plora_B = module.wo.Plora_B + del module.wo.Plora_A + del module.wo.Plora_B + + elif module.__class__.__name__ == "InternLM2MLP": + module.w1_lora_scaling = module.w1.lora_scaling + module.w1_Plora_A = module.w1.Plora_A + module.w1_Plora_B = module.w1.Plora_B + del module.w1.Plora_A + del module.w1.Plora_B + + module.w2_lora_scaling = module.w2.lora_scaling + module.w2_Plora_A = module.w2.Plora_A + module.w2_Plora_B = module.w2.Plora_B + del module.w2.Plora_A + del module.w2.Plora_B + + module.w3_lora_scaling = module.w3.lora_scaling + module.w3_Plora_A = module.w3.Plora_A + module.w3_Plora_B = module.w3.Plora_B + del module.w3.Plora_A + del module.w3.Plora_B + + +def add_lora(x: torch.Tensor, result: torch.Tensor, + im_mask: torch.Tensor = None, lora_scaling: float = 0, + Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None): + if im_mask is not None and torch.sum(im_mask) > 0: + part_x = x[im_mask] + result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling) + return result + + +def internlm_xcomposser2_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + im_mask: Optional[Tuple[torch.Tensor]] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + qkv_states = add_lora(hidden_states, qkv_states, im_mask, self.wqkv_lora_scaling, + self.wqkv_Plora_A, self.wqkv_Plora_B) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids, "internlm") + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output_2 = self.wo(attn_output) + + attn_output = add_lora(attn_output, attn_output_2, im_mask, self.wo_lora_scaling, + self.wo_Plora_A, self.wo_Plora_B) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def internlm_xcomposser2_mlp_forward( + self, + x: torch.Tensor, + im_mask: Optional[Tuple[torch.Tensor]] = None, +): + w1 = self.w1(x) + w1 = add_lora(x, w1, im_mask, self.w1_lora_scaling, self.w1_Plora_A, self.w1_Plora_B) + w3 = self.w3(x) + w3 = add_lora(x, w3, im_mask, self.w3_lora_scaling, self.w3_Plora_A, self.w3_Plora_B) + x = self.act_fn(w1) * w3 + w2 = self.w2(x) + w2 = add_lora(x, w2, im_mask, self.w2_lora_scaling, self.w2_Plora_A, self.w2_Plora_B) + return w2