optimize internlm xcomposer2 performance (#11550)
This commit is contained in:

parent 3c16c9f725
commit 82f9514303

1 changed file with 27 additions and 6 deletions
@@ -42,6 +42,7 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -307,9 +308,15 @@ def pre_process_attn_and_mlp(module: torch.nn.Module):
 def add_lora(x: torch.Tensor, result: torch.Tensor,
              im_mask: torch.Tensor = None, lora_scaling: float = 0,
              Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
-    if im_mask is not None and torch.sum(im_mask) > 0:
-        part_x = x[im_mask]
-        result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
-    return result
+    invalidInputError(x.dim() == 3 and result.dim() == 3,
+                      "`x` and `result` should have 3 dims")
+    if len(im_mask) == 0 or x.size(1) == 1:
+        return result
+    else:
+        for start_idx, end_idx in im_mask:
+            result[:, start_idx:end_idx, :] += Plora_B(
+                Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
+            )
+        return result
 
 
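The rewrite replaces boolean advanced indexing (`x[im_mask]`), which gathers the masked tokens into a fresh tensor and scatters the LoRA delta back on every patched linear layer, with plain slicing over contiguous `(start_idx, end_idx)` spans, which only creates views. A minimal standalone sketch of the two strategies; the toy shapes, random inputs, and `lora_scaling` value of 1.0 are assumptions for illustration, only the indexing patterns mirror the diff above:

# Sketch comparing the old and new PLoRA paths (toy sizes are assumptions).
import torch

hidden, rank = 16, 4
Plora_A = torch.nn.Linear(hidden, rank, bias=False)
Plora_B = torch.nn.Linear(rank, hidden, bias=False)
x = torch.randn(1, 8, hidden)

# Old path: boolean advanced indexing copies the masked tokens out,
# then scatters the LoRA delta back in.
bool_mask = torch.zeros(1, 8, dtype=torch.bool)
bool_mask[:, 2:5] = True
old = torch.zeros_like(x)
old[bool_mask] += Plora_B(Plora_A(x[bool_mask]) * 1.0)

# New path: the same tokens addressed as contiguous slices; basic
# slicing returns views, so no gather/scatter tensors are allocated.
new = torch.zeros_like(x)
for start_idx, end_idx in [(2, 5)]:
    new[:, start_idx:end_idx, :] += Plora_B(
        Plora_A(x[:, start_idx:end_idx, :]) * 1.0
    )

assert torch.allclose(old, new)  # identical result, cheaper indexing

The `x.size(1) == 1` early-out also means decode steps (a single new token, never an image token) skip the LoRA branch entirely.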
@@ -457,18 +464,32 @@ def internlm_xcomposser2_chat(
     **kwargs,
 ):
     # ipex-llm changes start: fix device and dtype conversion
+    # replace im_mask with start_idx and end_idx to improve performance
     if image is None:
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
-        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+        im_mask = []
     else:
         image = self.encode_img(image)
         inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                    history, meta_instruction)
+        mask = im_mask.cpu().flatten().tolist()
+        length = len(mask)
+        im_mask = []
+        i = 0
+        while i < length:
+            while i < length and not mask[i]:
+                i = i + 1
+            start_idx = i
+            while i < length and mask[i]:
+                i = i + 1
+            end_idx = i
+            if start_idx != end_idx:
+                im_mask.append((start_idx, end_idx))
+
     inputs = {
         k: v.to(device=self.device, dtype=self.dtype)
         for k, v in inputs.items() if torch.is_tensor(v)
     }
-    im_mask = im_mask.to(self.device)
     # ipex-llm changes end
 
     # also add end-of-assistant token in eos token id to avoid unnecessary generation
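For reference, the scan added in the patched chat path is just a run-length encoding of the flattened image-token mask: it converts a 0/1 mask into half-open `(start_idx, end_idx)` spans so `add_lora` can slice instead of boolean-index. A standalone sketch (`mask_to_spans` is a hypothetical helper name; the loop body mirrors the diff above):

# Run-length scan: boolean/0-1 mask -> list of (start_idx, end_idx) spans.
def mask_to_spans(mask):
    spans = []
    i, length = 0, len(mask)
    while i < length:
        while i < length and not mask[i]:   # skip text tokens
            i += 1
        start_idx = i
        while i < length and mask[i]:       # consume one run of image tokens
            i += 1
        end_idx = i
        if start_idx != end_idx:
            spans.append((start_idx, end_idx))
    return spans

# Two image-token runs: positions 2-4 and 7-8.
print(mask_to_spans([0, 0, 1, 1, 1, 0, 0, 1, 1]))  # [(2, 5), (7, 9)]

The conversion runs once per chat call over a plain Python list, after which `im_mask` is an ordinary list of spans, which is why the `im_mask = im_mask.to(self.device)` transfer could be dropped.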