optimize internlm xcomposer2 performance (#11550)
parent 3c16c9f725
commit 82f9514303
1 changed file with 27 additions and 6 deletions
@@ -42,6 +42,7 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
@@ -307,9 +308,15 @@ def pre_process_attn_and_mlp(module: torch.nn.Module):
 def add_lora(x: torch.Tensor, result: torch.Tensor,
              im_mask: torch.Tensor = None, lora_scaling: float = 0,
              Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
-    if im_mask is not None and torch.sum(im_mask) > 0:
-        part_x = x[im_mask]
-        result[im_mask] += Plora_B(Plora_A(part_x) * lora_scaling)
+    invalidInputError(x.dim() == 3 and result.dim() == 3,
+                      "`x` and `result` should have 3 dims")
+    if len(im_mask) == 0 or x.size(1) == 1:
+        return result
+    else:
+        for start_idx, end_idx in im_mask:
+            result[:, start_idx:end_idx, :] += Plora_B(
+                Plora_A(x[:, start_idx:end_idx, :]) * lora_scaling
+            )
     return result
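The rewritten add_lora applies the PLoRA branch to contiguous (start_idx, end_idx) spans with ordinary slicing instead of boolean-mask indexing, which replaces a gather plus an index_put_ scatter over every token with plain strided reads and writes. A minimal standalone sketch of the same idea, with made-up shapes and freshly initialized nn.Linear layers standing in for the model's Plora_A/Plora_B modules:

import torch
from torch import nn

# Illustrative shapes; the real modules live in the model checkpoint.
batch, seq, hidden, rank = 1, 16, 32, 4
x = torch.randn(batch, seq, hidden)
base = torch.zeros(batch, seq, hidden)
plora_a = nn.Linear(hidden, rank, bias=False)
plora_b = nn.Linear(rank, hidden, bias=False)
lora_scaling = 2.0

# Old path: boolean mask over all tokens (gather + index_put_ scatter).
bool_mask = torch.zeros(batch, seq, dtype=torch.bool)
bool_mask[:, 3:9] = True
ref = base.clone()
ref[bool_mask] += plora_b(plora_a(x[bool_mask]) * lora_scaling)

# New path: contiguous slices described by (start_idx, end_idx) pairs,
# matching the intervals the chat function now precomputes.
out = base.clone()
for start_idx, end_idx in [(3, 9)]:
    out[:, start_idx:end_idx, :] += plora_b(
        plora_a(x[:, start_idx:end_idx, :]) * lora_scaling
    )

print(torch.allclose(ref, out))  # True: both touch the same token positions

Both paths add the LoRA delta to exactly the same positions; the slice form just never materializes a boolean index tensor.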
@@ -457,18 +464,32 @@ def internlm_xcomposser2_chat(
         **kwargs,
 ):
     # ipex-llm changes start: fix device and dtype conversion
+    # replace im_mask with start_idx and end_idx to improve performance
     if image is None:
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
-        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
+        im_mask = []
     else:
         image = self.encode_img(image)
         inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                    history, meta_instruction)
+        mask = im_mask.cpu().flatten().tolist()
+        length = len(mask)
+        im_mask = []
+        i = 0
+        while i < length:
+            while i < length and not mask[i]:
+                i = i + 1
+            start_idx = i
+            while i < length and mask[i]:
+                i = i + 1
+            end_idx = i
+            if start_idx != end_idx:
+                im_mask.append((start_idx, end_idx))
+
     inputs = {
         k: v.to(device=self.device, dtype=self.dtype)
         for k, v in inputs.items() if torch.is_tensor(v)
     }
-    im_mask = im_mask.to(self.device)
     # ipex-llm changes end

     # also add end-of-assistant token in eos token id to avoid unnecessary generation
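The chat-time code above flattens the image-token mask once and compresses it into half-open (start_idx, end_idx) intervals, so every per-layer add_lora call works on a small Python list instead of a device tensor. A small self-contained check of that scan, using a hypothetical mask rather than real model output:

def mask_to_intervals(mask):
    # Same scan as in internlm_xcomposser2_chat: collect half-open
    # (start_idx, end_idx) runs of True values from a flat boolean list.
    intervals = []
    i, length = 0, len(mask)
    while i < length:
        while i < length and not mask[i]:
            i += 1
        start_idx = i
        while i < length and mask[i]:
            i += 1
        if start_idx != i:
            intervals.append((start_idx, i))
    return intervals

# Hypothetical mask: text tokens (False) interleaved with image tokens (True).
print(mask_to_intervals([False, True, True, False, False, True]))
# [(1, 3), (5, 6)]

Each run of True values becomes one interval, and an all-False mask yields an empty list, which matches the `len(im_mask) == 0` fast path in add_lora.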