optimize internlm xcomposer2 performance (#11550)
This commit is contained in:
parent
3c16c9f725
commit
82f9514303
1 changed files with 27 additions and 6 deletions
|
|
@ -42,6 +42,7 @@ from typing import Optional, Tuple, List
|
|||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||
|
|
@ -307,10 +308,16 @@ def pre_process_attn_and_mlp(module: torch.nn.Module):
|
|||
def add_lora(x: torch.Tensor, result: torch.Tensor,
|
||||
im_mask: torch.Tensor = None, lora_scaling: float = 0,
|
||||
Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
|
||||
if im_mask is not None and torch.sum(im_mask) > 0:
|
||||
part_x = x[im_mask]
|
||||
result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
|
||||
return result
|
||||
invalidInputError(x.dim() == 3 and result.dim() == 3,
|
||||
"`x` and `result` should have 3 dims")
|
||||
if len(im_mask) == 0 or x.size(1) == 1:
|
||||
return result
|
||||
else:
|
||||
for start_idx, end_idx in im_mask:
|
||||
result[:, start_idx:end_idx, :] += Plora_B(
|
||||
Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def internlm_xcomposser2_attention_forward(
|
||||
|
|
@ -457,18 +464,32 @@ def internlm_xcomposser2_chat(
|
|||
**kwargs,
|
||||
):
|
||||
# ipex-llm changes start: fix device and dtype conversion
|
||||
# replace im_mask with start_idx and end_idx to improve performance
|
||||
if image is None:
|
||||
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
|
||||
im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
|
||||
im_mask = []
|
||||
else:
|
||||
image = self.encode_img(image)
|
||||
inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
|
||||
history, meta_instruction)
|
||||
mask = im_mask.cpu().flatten().tolist()
|
||||
length = len(mask)
|
||||
im_mask = []
|
||||
i = 0
|
||||
while i < length:
|
||||
while i < length and not mask[i]:
|
||||
i = i + 1
|
||||
start_idx = i
|
||||
while i < length and mask[i]:
|
||||
i = i + 1
|
||||
end_idx = i
|
||||
if start_idx != end_idx:
|
||||
im_mask.append((start_idx, end_idx))
|
||||
|
||||
inputs = {
|
||||
k: v.to(device=self.device, dtype=self.dtype)
|
||||
for k, v in inputs.items() if torch.is_tensor(v)
|
||||
}
|
||||
im_mask = im_mask.to(self.device)
|
||||
# ipex-llm changes end
|
||||
|
||||
# also add end-of-assistant token in eos token id to avoid unnecessary generation
|
||||
|
|
|
|||
Loading…
Reference in a new issue