fix internlm xcomposser stream chat (#11564)
commit a945500a98
parent b9c66994a5

2 changed files with 74 additions and 23 deletions
@@ -1259,12 +1259,19 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "internlmxcomposer2":
         modeling_module_name = model.model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward
+        from ipex_llm.transformers.models.internlm import (
+            internlm_xcomposser2_attention_forward,
+            internlm_xcomposser2_mlp_forward,
+            internlm_xcomposser2_model_forward_wrapper,
+            internlm_xcomposser2_chat
+        )
         convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
         convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
         convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
+        internlm_xcomposser2_model_forward = internlm_xcomposser2_model_forward_wrapper(
+            module.InternLM2Model.forward
+        )
+        convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
         model.chat = MethodType(internlm_xcomposser2_chat, model)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
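The substantive change here is that `InternLM2Model.forward` is now also replaced, via a wrapper around the original forward; the scattered imports are merely consolidated into one statement. The point, judging from the commit title, is that stream chat presumably drives the model through `model.generate` directly and never through the patched `model.chat`, so the `im_mask` preprocessing has to happen inside the model's own forward. For orientation, a minimal sketch of what a `convert_forward`-style helper does, assuming it simply rebinds the replacement function on every submodule of the target class (the real ipex-llm helper may differ in detail):

    from types import MethodType

    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
        # Walk the module tree and rebind `new_forward` as the bound
        # `forward` method of every instance of `target_cls`.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = MethodType(new_forward, module)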
@@ -310,7 +310,7 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
              Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
     invalidInputError(x.dim() == 3 and result.dim() == 3,
                       "`x` and `result` should have 3 dims")
-    if len(im_mask) == 0 or x.size(1) == 1:
+    if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
         return result
     else:
         for start_idx, end_idx in im_mask:
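The new guard also covers the case where a raw, un-converted boolean mask reaches `add_lora`: iterating a 2-D tensor with `for start_idx, end_idx in im_mask` would mis-unpack its rows, so a tensor mask now short-circuits to the plain result (the decode-step shortcut `x.size(1) == 1` moves into the model-forward wrapper below). The loop body that applies the spans is elided from this hunk; based on the surrounding code it plausibly looks like the sketch below, a hypothetical reconstruction rather than the patch itself:

    import torch

    def add_lora_sketch(x: torch.Tensor, result: torch.Tensor, im_mask=None,
                        Plora_A: torch.nn.Linear = None,
                        Plora_B: torch.nn.Linear = None) -> torch.Tensor:
        # A raw (un-converted) tensor mask or an empty span list means
        # "nothing to do here": return the plain linear output.
        if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
            return result
        # Otherwise add the low-rank Plora update on image-token spans only.
        for start_idx, end_idx in im_mask:
            result[:, start_idx:end_idx] += Plora_B(Plora_A(x[:, start_idx:end_idx]))
        return result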
@@ -320,6 +320,56 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
         return result
 
 
+def internlm_xcomposser2_model_forward_wrapper(origin_forward):
+    def internlm_xcomposser2_model_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ):
+        im_mask = kwargs.get('im_mask', None)
+        if im_mask is None or im_mask.size(-1) <= 1 or im_mask.sum() == 0:
+            # decoding or no image input, `im_mask` is not needed
+            kwargs['im_mask'] = []
+        else:
+            # replace im_mask with start_idx and end_idx to improve performance
+            im_mask = im_mask.cpu().flatten().tolist()
+            length = len(im_mask)
+            new_mask = []
+            i = 0
+            while i < length:
+                while i < length and not im_mask[i]:
+                    i = i + 1
+                start_idx = i
+                while i < length and im_mask[i]:
+                    i = i + 1
+                end_idx = i
+                if start_idx != end_idx:
+                    new_mask.append((start_idx, end_idx))
+            kwargs['im_mask'] = new_mask
+        return origin_forward(
+            self=self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+    return internlm_xcomposser2_model_forward
+
+
 def internlm_xcomposser2_attention_forward(
     self,
     hidden_states: torch.Tensor,
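The wrapper converts the boolean `im_mask` into half-open `(start_idx, end_idx)` spans once per prefill, so the per-layer `add_lora` calls can slice instead of indexing with a mask; during decoding, `im_mask.size(-1) <= 1`, so the wrapper passes `[]` and every layer takes the fast no-image path. A small standalone check of that run-length scan (`rle_spans` is an illustrative name, not part of the patch):

    import torch

    def rle_spans(im_mask: torch.Tensor):
        # Same scan as in the wrapper: collapse a flat boolean mask into
        # half-open (start_idx, end_idx) spans of consecutive True values.
        mask = im_mask.cpu().flatten().tolist()
        spans, i = [], 0
        while i < len(mask):
            while i < len(mask) and not mask[i]:
                i += 1
            start_idx = i
            while i < len(mask) and mask[i]:
                i += 1
            if start_idx != i:
                spans.append((start_idx, i))
        return spans

    # Tokens 2-4 and 6 are image tokens:
    print(rle_spans(torch.tensor([0, 0, 1, 1, 1, 0, 1], dtype=torch.bool)))
    # [(2, 5), (6, 7)]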
@@ -466,32 +516,26 @@ def internlm_xcomposser2_chat(
     **kwargs,
 ):
     # ipex-llm changes start: fix device and dtype conversion
-    # replace im_mask with start_idx and end_idx to improve performance
     if image is None:
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
-        im_mask = []
+        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
     else:
         image = self.encode_img(image)
         inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                    history, meta_instruction)
-        mask = im_mask.cpu().flatten().tolist()
-        length = len(mask)
-        im_mask = []
-        i = 0
-        while i < length:
-            while i < length and not mask[i]:
-                i = i + 1
-            start_idx = i
-            while i < length and mask[i]:
-                i = i + 1
-            end_idx = i
-            if start_idx != end_idx:
-                im_mask.append((start_idx, end_idx))
 
-    inputs = {
-        k: v.to(device=self.device, dtype=self.dtype)
-        for k, v in inputs.items() if torch.is_tensor(v)
-    }
+    new_inputs = {}
+    for k, v in inputs.items():
+        if torch.is_tensor(v):
+            if v.dtype.is_floating_point:
+                new_inputs[k] = v.to(device=self.device, dtype=self.dtype)
+            else:
+                # input_ids, don't convert its dtype
+                new_inputs[k] = v.to(device=self.device)
+        else:
+            new_inputs[k] = v
+    inputs = new_inputs
+    im_mask = im_mask.to(self.device)
     # ipex-llm changes end
 
     # also add end-of-assistant token in eos token id to avoid unnecessary generation
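Two things happen in this hunk. First, `chat` no longer pre-converts `im_mask` to spans; it stays a tensor (an all-False mask when there is no image) and the wrapped model forward now does the conversion in one place, so `generate`-based callers get the same treatment. Second, the rewritten conversion loop replaces the old dict comprehension, which cast every tensor, `input_ids` included, to the model dtype. A quick illustration of why integer tensors must keep their dtype, assuming a half-precision model (`self.dtype == torch.float16`):

    import torch

    input_ids = torch.tensor([[1, 32771, 364]])  # token ids, torch.int64

    # Old behaviour: blanket cast to the model dtype.
    bad = input_ids.to(dtype=torch.float16)
    # 32771 is not exactly representable in fp16 and rounds to 32768.0;
    # embedding lookups also reject floating-point indices outright.

    # New behaviour: non-floating tensors only change device.
    good = input_ids.to(device="cpu")            # dtype stays int64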