diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 7a02cc05..c0d94b6e 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1259,12 +1259,19 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "internlmxcomposer2":
         modeling_module_name = model.model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward
+        from ipex_llm.transformers.models.internlm import (
+            internlm_xcomposser2_attention_forward,
+            internlm_xcomposser2_mlp_forward,
+            internlm_xcomposser2_model_forward_wrapper,
+            internlm_xcomposser2_chat
+        )
         convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
         convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
         convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
+        internlm_xcomposser2_model_forward = internlm_xcomposser2_model_forward_wrapper(
+            module.InternLM2Model.forward
+        )
+        convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
         model.chat = MethodType(internlm_xcomposser2_chat, model)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py
index df0ffd5d..a3dd0cb8 100644
--- a/python/llm/src/ipex_llm/transformers/models/internlm.py
+++ b/python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -310,7 +310,7 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
              Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
     invalidInputError(x.dim() == 3 and result.dim() == 3,
                       "`x` and `result` should have 3 dims")
-    if len(im_mask) == 0 or x.size(1) == 1:
+    if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
         return result
     else:
         for start_idx, end_idx in im_mask:
@@ -320,6 +320,56 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
         return result
 
+
+def internlm_xcomposser2_model_forward_wrapper(origin_forward):
+    def internlm_xcomposser2_model_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ):
+        im_mask = kwargs.get('im_mask', None)
+        if im_mask is None or im_mask.size(-1) <= 1 or im_mask.sum() == 0:
+            # decoding or no image input, `im_mask` is not needed
+            kwargs['im_mask'] = []
+        else:
+            # replace im_mask with start_idx and end_idx to improve performance
+            im_mask = im_mask.cpu().flatten().tolist()
+            length = len(im_mask)
+            new_mask = []
+            i = 0
+            while i < length:
+                while i < length and not im_mask[i]:
+                    i = i + 1
+                start_idx = i
+                while i < length and im_mask[i]:
+                    i = i + 1
+                end_idx = i
+                if start_idx != end_idx:
+                    new_mask.append((start_idx, end_idx))
+            kwargs['im_mask'] = new_mask
+        return origin_forward(
+            self=self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+    return internlm_xcomposser2_model_forward
+
 
 def internlm_xcomposser2_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -466,32 +516,26 @@ def internlm_xcomposser2_chat(
     **kwargs,
 ):
     # ipex-llm changes start: fix device and dtype conversion
-    # replace im_mask with start_idx and end_idx to improve performance
     if image is None:
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
-        im_mask = []
+        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
     else:
         image = self.encode_img(image)
         inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                    history, meta_instruction)
-        mask = im_mask.cpu().flatten().tolist()
-        length = len(mask)
-        im_mask = []
-        i = 0
-        while i < length:
-            while i < length and not mask[i]:
-                i = i + 1
-            start_idx = i
-            while i < length and mask[i]:
-                i = i + 1
-            end_idx = i
-            if start_idx != end_idx:
-                im_mask.append((start_idx, end_idx))
-    inputs = {
-        k: v.to(device=self.device, dtype=self.dtype)
-        for k, v in inputs.items() if torch.is_tensor(v)
-    }
+    new_inputs = {}
+    for k, v in inputs.items():
+        if torch.is_tensor(v):
+            if v.dtype.is_floating_point:
+                new_inputs[k] = v.to(device=self.device, dtype=self.dtype)
+            else:
+                # input_ids, don't convert its dtype
+                new_inputs[k] = v.to(device=self.device)
+        else:
+            new_inputs[k] = v
+    inputs = new_inputs
+    im_mask = im_mask.to(self.device)
     # ipex-llm changes end
     # also add end-of-assistant token in eos token id to avoid unnecessary generation
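
Note on the `internlm_xcomposser2_model_forward_wrapper` hunk: it collapses the flat boolean `im_mask` into a list of contiguous `(start_idx, end_idx)` ranges once per prefill, so `add_lora` only slices a few spans instead of re-evaluating the mask in every layer. A minimal standalone sketch of that idea; the helper names (`mask_to_ranges`, `apply_lora_on_ranges`) and the toy shapes are illustrative, not part of ipex-llm:

```python
import torch


def mask_to_ranges(im_mask: torch.Tensor):
    # Collapse a flat 0/1 mask into half-open (start, end) ranges,
    # mirroring the while-loop in the wrapper above.
    flat = im_mask.cpu().flatten().tolist()
    length, ranges, i = len(flat), [], 0
    while i < length:
        while i < length and not flat[i]:
            i += 1
        start_idx = i
        while i < length and flat[i]:
            i += 1
        if start_idx != i:
            ranges.append((start_idx, i))
    return ranges


def apply_lora_on_ranges(x, result, ranges, lora_A, lora_B, lora_scaling=1.0):
    # Add the low-rank update only on the image-token spans,
    # the way add_lora consumes the precomputed ranges.
    for start_idx, end_idx in ranges:
        result[:, start_idx:end_idx] += (
            lora_B(lora_A(x[:, start_idx:end_idx])) * lora_scaling
        )
    return result


if __name__ == "__main__":
    seq_len, hidden, rank = 16, 32, 4
    im_mask = torch.zeros(1, seq_len, dtype=torch.bool)
    im_mask[0, 3:9] = True    # first image segment
    im_mask[0, 12:15] = True  # second image segment
    print(mask_to_ranges(im_mask))  # [(3, 9), (12, 15)]

    x = torch.randn(1, seq_len, hidden)
    out = apply_lora_on_ranges(x, torch.zeros(1, seq_len, hidden),
                               mask_to_ranges(im_mask),
                               torch.nn.Linear(hidden, rank, bias=False),
                               torch.nn.Linear(rank, hidden, bias=False))
    # positions outside the image spans are left untouched
    print(out[0, 0].abs().sum().item(), out[0, 3].abs().sum().item())
```

Text-only prompts and single-token decode steps skip this path entirely: the wrapper passes an empty list, and `add_lora` returns `result` unchanged.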
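Note on the `internlm_xcomposser2_chat` hunk: the old dict comprehension cast every tensor in `inputs` to `self.dtype`, including the integer `input_ids`; the new loop casts only floating-point tensors and merely moves the rest to the target device. A small sketch of that conversion rule, using an invented helper name (`move_inputs`) and made-up example inputs:

```python
import torch


def move_inputs(inputs: dict, device, dtype):
    # Cast only floating-point tensors to the model dtype; integer tensors
    # (e.g. input_ids) keep their dtype and only change device.
    new_inputs = {}
    for k, v in inputs.items():
        if torch.is_tensor(v):
            if v.dtype.is_floating_point:
                new_inputs[k] = v.to(device=device, dtype=dtype)
            else:
                new_inputs[k] = v.to(device=device)
        else:
            new_inputs[k] = v
    return new_inputs


if __name__ == "__main__":
    batch = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "inputs_embeds": torch.randn(1, 3, 8),
        "meta": "not a tensor",
    }
    out = move_inputs(batch, device="cpu", dtype=torch.float16)
    print(out["input_ids"].dtype)      # torch.int64, unchanged
    print(out["inputs_embeds"].dtype)  # torch.float16
```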