support internvl2-4b (#11718)
This commit is contained in:
parent
7f241133da
commit
f44b732aa8
2 changed files with 177 additions and 0 deletions
|
|
@ -1256,6 +1256,18 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
)
|
)
|
||||||
convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
|
convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
|
||||||
model.chat = MethodType(internlm_xcomposser2_chat, model)
|
model.chat = MethodType(internlm_xcomposser2_chat, model)
|
||||||
|
elif model.config.model_type == "internvl_chat":
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from ipex_llm.transformers.models.internvl import internvl_chat
|
||||||
|
from ipex_llm.transformers.models.internvl import internvl_batch_chat
|
||||||
|
model.get_conv_template = module.get_conv_template
|
||||||
|
model.chat = MethodType(internvl_chat, model)
|
||||||
|
model.batch_chat = MethodType(internvl_batch_chat, model)
|
||||||
|
if model.vision_model.__class__.__name__ == "InternVisionModel":
|
||||||
|
from ipex_llm.transformers.models.internvl import _get_pos_embed
|
||||||
|
vision_embedding = model.vision_model.embeddings
|
||||||
|
vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
|
||||||
elif model.config.model_type == "qwen":
|
elif model.config.model_type == "qwen":
|
||||||
if hasattr(model.config, "visual"):
|
if hasattr(model.config, "visual"):
|
||||||
# for Qwen-VL-Chat
|
# for Qwen-VL-Chat
|
||||||
|
|
|
||||||
165
python/llm/src/ipex_llm/transformers/models/internvl.py
Normal file
165
python/llm/src/ipex_llm/transformers/models/internvl.py
Normal file
|
|
@ -0,0 +1,165 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Some parts of this file is adapted from
|
||||||
|
# https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||||
|
# which is licensed under MIT:
|
||||||
|
#
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# InternVL
|
||||||
|
# Copyright (c) 2024 OpenGVLab
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# --------------------------------------------------------
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pos_embed(self, pos_embed, H, W):
|
||||||
|
target_dtype = pos_embed.dtype
|
||||||
|
device = pos_embed.device
|
||||||
|
pos_embed = pos_embed.float().reshape(
|
||||||
|
1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1
|
||||||
|
).permute(0, 3, 1, 2)
|
||||||
|
# ipex-llm change start: call interpolate on CPU to fix bug
|
||||||
|
pos_embed = torch.nn.functional.interpolate(
|
||||||
|
pos_embed.to('cpu'), size=(H, W), mode='bicubic', align_corners=False
|
||||||
|
).reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype).to(device)
|
||||||
|
# ipex-llm changes end
|
||||||
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def internvl_chat(self, tokenizer, pixel_values, question, generation_config,
|
||||||
|
history=None, return_history=False, num_patches_list=None,
|
||||||
|
IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
|
||||||
|
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
|
||||||
|
|
||||||
|
if history is None and pixel_values is not None and '<image>' not in question:
|
||||||
|
question = '<image>\n' + question
|
||||||
|
|
||||||
|
if num_patches_list is None:
|
||||||
|
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
|
||||||
|
invalidInputError(pixel_values is None or len(pixel_values) == sum(num_patches_list),
|
||||||
|
"wrong num_patches_list length")
|
||||||
|
|
||||||
|
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||||
|
self.img_context_token_id = img_context_token_id
|
||||||
|
|
||||||
|
template = self.get_conv_template(self.template)
|
||||||
|
template.system_message = self.system_message
|
||||||
|
eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
|
||||||
|
|
||||||
|
history = [] if history is None else history
|
||||||
|
for (old_question, old_answer) in history:
|
||||||
|
template.append_message(template.roles[0], old_question)
|
||||||
|
template.append_message(template.roles[1], old_answer)
|
||||||
|
template.append_message(template.roles[0], question)
|
||||||
|
template.append_message(template.roles[1], None)
|
||||||
|
query = template.get_prompt()
|
||||||
|
|
||||||
|
if verbose and pixel_values is not None:
|
||||||
|
image_bs = pixel_values.shape[0]
|
||||||
|
print(f'dynamic ViT batch size: {image_bs}')
|
||||||
|
|
||||||
|
for num_patches in num_patches_list:
|
||||||
|
image_tokens = (IMG_START_TOKEN
|
||||||
|
+ IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
|
||||||
|
+ IMG_END_TOKEN)
|
||||||
|
query = query.replace('<image>', image_tokens, 1)
|
||||||
|
model_inputs = tokenizer(query, return_tensors='pt')
|
||||||
|
|
||||||
|
# ipex-llm changes start: move input_ids and attention_mask to xpu
|
||||||
|
input_ids = model_inputs['input_ids'].to(self.device)
|
||||||
|
attention_mask = model_inputs['attention_mask'].to(self.device)
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.to(dtype=self.dtype, device=self.device)
|
||||||
|
# ipex-llm changes end
|
||||||
|
|
||||||
|
generation_config['eos_token_id'] = eos_token_id
|
||||||
|
generation_output = self.generate(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
**generation_config
|
||||||
|
)
|
||||||
|
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
|
||||||
|
response = response.split(template.sep)[0].strip()
|
||||||
|
history.append((question, response))
|
||||||
|
if return_history:
|
||||||
|
return response, history
|
||||||
|
else:
|
||||||
|
query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
|
||||||
|
query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
|
||||||
|
if verbose:
|
||||||
|
print(query_to_print, response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def internvl_batch_chat(self, tokenizer, pixel_values, questions, generation_config,
|
||||||
|
num_patches_list=None, history=None, return_history=False,
|
||||||
|
IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
|
||||||
|
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
|
||||||
|
invalidInputError(history is None and not return_history,
|
||||||
|
'Now multi-turn chat is not supported in batch_chat.')
|
||||||
|
|
||||||
|
if image_counts is not None:
|
||||||
|
num_patches_list = image_counts
|
||||||
|
print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
|
||||||
|
|
||||||
|
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||||
|
self.img_context_token_id = img_context_token_id
|
||||||
|
|
||||||
|
if verbose and pixel_values is not None:
|
||||||
|
image_bs = pixel_values.shape[0]
|
||||||
|
print(f'dynamic ViT batch size: {image_bs}')
|
||||||
|
|
||||||
|
queries = []
|
||||||
|
for idx, num_patches in enumerate(num_patches_list):
|
||||||
|
question = questions[idx]
|
||||||
|
if pixel_values is not None and '<image>' not in question:
|
||||||
|
question = '<image>\n' + question
|
||||||
|
template = self.get_conv_template(self.template)
|
||||||
|
template.append_message(template.roles[0], question)
|
||||||
|
template.append_message(template.roles[1], None)
|
||||||
|
query = template.get_prompt()
|
||||||
|
|
||||||
|
image_tokens = (IMG_START_TOKEN
|
||||||
|
+ IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
|
||||||
|
+ IMG_END_TOKEN)
|
||||||
|
query = query.replace('<image>', image_tokens, 1)
|
||||||
|
queries.append(query)
|
||||||
|
|
||||||
|
tokenizer.padding_side = 'left'
|
||||||
|
model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
|
||||||
|
|
||||||
|
# ipex-llm changes start: move input_ids and attention_mask to xpu
|
||||||
|
input_ids = model_inputs['input_ids'].to(self.device)
|
||||||
|
attention_mask = model_inputs['attention_mask'].to(self.device)
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.to(dtype=self.dtype, device=self.device)
|
||||||
|
# ipex-llm changes end
|
||||||
|
|
||||||
|
eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
|
||||||
|
generation_config['eos_token_id'] = eos_token_id
|
||||||
|
generation_output = self.generate(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
**generation_config
|
||||||
|
)
|
||||||
|
responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
|
||||||
|
responses = [response.split(template.sep)[0].strip() for response in responses]
|
||||||
|
return responses
|
||||||
Loading…
Reference in a new issue