fix internlm xcomposser stream chat (#11564)
parent b9c66994a5
commit a945500a98
2 changed files with 74 additions and 23 deletions
@@ -1259,12 +1259,19 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "internlmxcomposer2":
         modeling_module_name = model.model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward
+        from ipex_llm.transformers.models.internlm import (
+            internlm_xcomposser2_attention_forward,
+            internlm_xcomposser2_mlp_forward,
+            internlm_xcomposser2_model_forward_wrapper,
+            internlm_xcomposser2_chat
+        )
         convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
         convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
         convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
-        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
+        internlm_xcomposser2_model_forward = internlm_xcomposser2_model_forward_wrapper(
+            module.InternLM2Model.forward
+        )
+        convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
         model.chat = MethodType(internlm_xcomposser2_chat, model)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
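Note on this first hunk (the `_optimize_post` dispatch, presumably in ipex_llm's convert.py): besides swapping the attention and MLP forwards, the model-level forward is now wrapped so that `im_mask` can be preprocessed once per call. A minimal sketch of the wrap-and-rebind pattern this relies on, using a hypothetical `convert_forward_sketch` helper in place of ipex_llm's own `convert_forward`:

from types import MethodType
import torch.nn as nn

def forward_wrapper(origin_forward):
    # capture the original forward and return a version that can rewrite kwargs first
    def wrapped_forward(self, *args, **kwargs):
        # e.g. normalize `im_mask` here before delegating to the original implementation
        return origin_forward(self, *args, **kwargs)
    return wrapped_forward

def convert_forward_sketch(model: nn.Module, target_class, new_forward):
    # walk the module tree and rebind `forward` on every instance of `target_class`
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = MethodType(new_forward, module)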
@@ -310,7 +310,7 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
              Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
     invalidInputError(x.dim() == 3 and result.dim() == 3,
                       "`x` and `result` should have 3 dims")
-    if len(im_mask) == 0 or x.size(1) == 1:
+    if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
         return result
     else:
         for start_idx, end_idx in im_mask:
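For context on the changed condition: once the wrapped model forward rewrites `im_mask` into a list of `(start_idx, end_idx)` spans, a raw tensor mask reaching `add_lora` means no preprocessed spans are available, so the base projection is returned unchanged. A rough sketch of the span-based Partial-LoRA application this enables (the slicing in the loop is an assumption, since the loop body lies outside this hunk):

import torch

def add_lora_sketch(x, result, im_mask, Plora_A, Plora_B):
    # x, result: (batch, seq_len, hidden); im_mask: [] or [(start_idx, end_idx), ...]
    if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
        # decode step or no image tokens: keep the base projection
        return result
    for start_idx, end_idx in im_mask:
        # apply the LoRA delta only on the image-token spans (assumed slicing)
        result[:, start_idx:end_idx] += Plora_B(Plora_A(x[:, start_idx:end_idx]))
    return result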
@@ -320,6 +320,56 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
         return result
 
 
+def internlm_xcomposser2_model_forward_wrapper(origin_forward):
+    def internlm_xcomposser2_model_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ):
+        im_mask = kwargs.get('im_mask', None)
+        if im_mask is None or im_mask.size(-1) <= 1 or im_mask.sum() == 0:
+            # decoding or no image input, `im_mask` is not needed
+            kwargs['im_mask'] = []
+        else:
+            # replace im_mask with start_idx and end_idx to improve performance
+            im_mask = im_mask.cpu().flatten().tolist()
+            length = len(im_mask)
+            new_mask = []
+            i = 0
+            while i < length:
+                while i < length and not im_mask[i]:
+                    i = i + 1
+                start_idx = i
+                while i < length and im_mask[i]:
+                    i = i + 1
+                end_idx = i
+                if start_idx != end_idx:
+                    new_mask.append((start_idx, end_idx))
+            kwargs['im_mask'] = new_mask
+        return origin_forward(
+            self=self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+    return internlm_xcomposser2_model_forward
+
+
 def internlm_xcomposser2_attention_forward(
     self,
     hidden_states: torch.Tensor,
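A quick standalone check of the run-length encoding the new wrapper performs on the flattened boolean image mask (toy values, not from the PR):

mask = [0, 1, 1, 1, 0, 0, 1, 1]
spans, i = [], 0
while i < len(mask):
    # skip text tokens, then collect one contiguous run of image tokens
    while i < len(mask) and not mask[i]:
        i += 1
    start_idx = i
    while i < len(mask) and mask[i]:
        i += 1
    if start_idx != i:
        spans.append((start_idx, i))
print(spans)  # [(1, 4), (6, 8)]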
@@ -466,32 +516,26 @@ def internlm_xcomposser2_chat(
     **kwargs,
 ):
     # ipex-llm changes start: fix device and dtype conversion
-    # replace im_mask with start_idx and end_idx to improve performance
     if image is None:
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
-        im_mask = []
+        im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
     else:
         image = self.encode_img(image)
         inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
                                                    history, meta_instruction)
-        mask = im_mask.cpu().flatten().tolist()
-        length = len(mask)
-        im_mask = []
-        i = 0
-        while i < length:
-            while i < length and not mask[i]:
-                i = i + 1
-            start_idx = i
-            while i < length and mask[i]:
-                i = i + 1
-            end_idx = i
-            if start_idx != end_idx:
-                im_mask.append((start_idx, end_idx))
 
-    inputs = {
-        k: v.to(device=self.device, dtype=self.dtype)
-        for k, v in inputs.items() if torch.is_tensor(v)
-    }
+    new_inputs = {}
+    for k, v in inputs.items():
+        if torch.is_tensor(v):
+            if v.dtype.is_floating_point:
+                new_inputs[k] = v.to(device=self.device, dtype=self.dtype)
+            else:
+                # input_ids, don't convert its dtype
+                new_inputs[k] = v.to(device=self.device)
+        else:
+            new_inputs[k] = v
+    inputs = new_inputs
+    im_mask = im_mask.to(self.device)
     # ipex-llm changes end
 
     # also add end-of-assistant token in eos token id to avoid unnecessary generation
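The device/dtype fix in this last hunk boils down to one rule: floating-point tensors are cast to the model dtype, integer tensors such as input_ids only move to the target device, and non-tensor entries pass through untouched. Repackaged as a standalone helper for illustration (move_inputs is a hypothetical name, not part of ipex-llm):

import torch

def move_inputs(inputs: dict, device, dtype):
    out = {}
    for k, v in inputs.items():
        if torch.is_tensor(v):
            if v.dtype.is_floating_point:
                out[k] = v.to(device=device, dtype=dtype)  # e.g. image embeddings
            else:
                out[k] = v.to(device=device)               # e.g. input_ids keep their dtype
        else:
            out[k] = v
    return out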