support internvl2-4b (#11718)
This commit is contained in:
		
							parent
							
								
									7f241133da
								
							
						
					
					
						commit
						f44b732aa8
					
				
					 2 changed files with 177 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -1256,6 +1256,18 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        )
 | 
			
		||||
        convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
 | 
			
		||||
        model.chat = MethodType(internlm_xcomposser2_chat, model)
 | 
			
		||||
    elif model.config.model_type == "internvl_chat":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.internvl import internvl_chat
 | 
			
		||||
        from ipex_llm.transformers.models.internvl import internvl_batch_chat
 | 
			
		||||
        model.get_conv_template = module.get_conv_template
 | 
			
		||||
        model.chat = MethodType(internvl_chat, model)
 | 
			
		||||
        model.batch_chat = MethodType(internvl_batch_chat, model)
 | 
			
		||||
        if model.vision_model.__class__.__name__ == "InternVisionModel":
 | 
			
		||||
            from ipex_llm.transformers.models.internvl import _get_pos_embed
 | 
			
		||||
            vision_embedding = model.vision_model.embeddings
 | 
			
		||||
            vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
 | 
			
		||||
    elif model.config.model_type == "qwen":
 | 
			
		||||
        if hasattr(model.config, "visual"):
 | 
			
		||||
            # for Qwen-VL-Chat
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										165
									
								
								python/llm/src/ipex_llm/transformers/models/internvl.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								python/llm/src/ipex_llm/transformers/models/internvl.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,165 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
# Some parts of this file is adapted from
 | 
			
		||||
# https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
 | 
			
		||||
# which is licensed under MIT:
 | 
			
		||||
#
 | 
			
		||||
# --------------------------------------------------------
 | 
			
		||||
# InternVL
 | 
			
		||||
# Copyright (c) 2024 OpenGVLab
 | 
			
		||||
# Licensed under The MIT License [see LICENSE for details]
 | 
			
		||||
# --------------------------------------------------------
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_pos_embed(self, pos_embed, H, W):
 | 
			
		||||
    target_dtype = pos_embed.dtype
 | 
			
		||||
    device = pos_embed.device
 | 
			
		||||
    pos_embed = pos_embed.float().reshape(
 | 
			
		||||
        1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1
 | 
			
		||||
    ).permute(0, 3, 1, 2)
 | 
			
		||||
    # ipex-llm change start: call interpolate on CPU to fix bug
 | 
			
		||||
    pos_embed = torch.nn.functional.interpolate(
 | 
			
		||||
        pos_embed.to('cpu'), size=(H, W), mode='bicubic', align_corners=False
 | 
			
		||||
    ).reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype).to(device)
 | 
			
		||||
    # ipex-llm changes end
 | 
			
		||||
    return pos_embed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def internvl_chat(self, tokenizer, pixel_values, question, generation_config,
 | 
			
		||||
                  history=None, return_history=False, num_patches_list=None,
 | 
			
		||||
                  IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
 | 
			
		||||
                  IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
 | 
			
		||||
 | 
			
		||||
    if history is None and pixel_values is not None and '<image>' not in question:
 | 
			
		||||
        question = '<image>\n' + question
 | 
			
		||||
 | 
			
		||||
    if num_patches_list is None:
 | 
			
		||||
        num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
 | 
			
		||||
    invalidInputError(pixel_values is None or len(pixel_values) == sum(num_patches_list),
 | 
			
		||||
                      "wrong num_patches_list length")
 | 
			
		||||
 | 
			
		||||
    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
 | 
			
		||||
    self.img_context_token_id = img_context_token_id
 | 
			
		||||
 | 
			
		||||
    template = self.get_conv_template(self.template)
 | 
			
		||||
    template.system_message = self.system_message
 | 
			
		||||
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
 | 
			
		||||
 | 
			
		||||
    history = [] if history is None else history
 | 
			
		||||
    for (old_question, old_answer) in history:
 | 
			
		||||
        template.append_message(template.roles[0], old_question)
 | 
			
		||||
        template.append_message(template.roles[1], old_answer)
 | 
			
		||||
    template.append_message(template.roles[0], question)
 | 
			
		||||
    template.append_message(template.roles[1], None)
 | 
			
		||||
    query = template.get_prompt()
 | 
			
		||||
 | 
			
		||||
    if verbose and pixel_values is not None:
 | 
			
		||||
        image_bs = pixel_values.shape[0]
 | 
			
		||||
        print(f'dynamic ViT batch size: {image_bs}')
 | 
			
		||||
 | 
			
		||||
    for num_patches in num_patches_list:
 | 
			
		||||
        image_tokens = (IMG_START_TOKEN
 | 
			
		||||
                        + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
 | 
			
		||||
                        + IMG_END_TOKEN)
 | 
			
		||||
        query = query.replace('<image>', image_tokens, 1)
 | 
			
		||||
    model_inputs = tokenizer(query, return_tensors='pt')
 | 
			
		||||
 | 
			
		||||
    # ipex-llm changes start: move input_ids and attention_mask to xpu
 | 
			
		||||
    input_ids = model_inputs['input_ids'].to(self.device)
 | 
			
		||||
    attention_mask = model_inputs['attention_mask'].to(self.device)
 | 
			
		||||
    if pixel_values is not None:
 | 
			
		||||
        pixel_values = pixel_values.to(dtype=self.dtype, device=self.device)
 | 
			
		||||
    # ipex-llm changes end
 | 
			
		||||
 | 
			
		||||
    generation_config['eos_token_id'] = eos_token_id
 | 
			
		||||
    generation_output = self.generate(
 | 
			
		||||
        pixel_values=pixel_values,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
        attention_mask=attention_mask,
 | 
			
		||||
        **generation_config
 | 
			
		||||
    )
 | 
			
		||||
    response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
 | 
			
		||||
    response = response.split(template.sep)[0].strip()
 | 
			
		||||
    history.append((question, response))
 | 
			
		||||
    if return_history:
 | 
			
		||||
        return response, history
 | 
			
		||||
    else:
 | 
			
		||||
        query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
 | 
			
		||||
        query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
 | 
			
		||||
        if verbose:
 | 
			
		||||
            print(query_to_print, response)
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def internvl_batch_chat(self, tokenizer, pixel_values, questions, generation_config,
 | 
			
		||||
                        num_patches_list=None, history=None, return_history=False,
 | 
			
		||||
                        IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
 | 
			
		||||
                        IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
 | 
			
		||||
    invalidInputError(history is None and not return_history,
 | 
			
		||||
                      'Now multi-turn chat is not supported in batch_chat.')
 | 
			
		||||
 | 
			
		||||
    if image_counts is not None:
 | 
			
		||||
        num_patches_list = image_counts
 | 
			
		||||
        print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
 | 
			
		||||
 | 
			
		||||
    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
 | 
			
		||||
    self.img_context_token_id = img_context_token_id
 | 
			
		||||
 | 
			
		||||
    if verbose and pixel_values is not None:
 | 
			
		||||
        image_bs = pixel_values.shape[0]
 | 
			
		||||
        print(f'dynamic ViT batch size: {image_bs}')
 | 
			
		||||
 | 
			
		||||
    queries = []
 | 
			
		||||
    for idx, num_patches in enumerate(num_patches_list):
 | 
			
		||||
        question = questions[idx]
 | 
			
		||||
        if pixel_values is not None and '<image>' not in question:
 | 
			
		||||
            question = '<image>\n' + question
 | 
			
		||||
        template = self.get_conv_template(self.template)
 | 
			
		||||
        template.append_message(template.roles[0], question)
 | 
			
		||||
        template.append_message(template.roles[1], None)
 | 
			
		||||
        query = template.get_prompt()
 | 
			
		||||
 | 
			
		||||
        image_tokens = (IMG_START_TOKEN
 | 
			
		||||
                        + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
 | 
			
		||||
                        + IMG_END_TOKEN)
 | 
			
		||||
        query = query.replace('<image>', image_tokens, 1)
 | 
			
		||||
        queries.append(query)
 | 
			
		||||
 | 
			
		||||
    tokenizer.padding_side = 'left'
 | 
			
		||||
    model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
 | 
			
		||||
 | 
			
		||||
    # ipex-llm changes start: move input_ids and attention_mask to xpu
 | 
			
		||||
    input_ids = model_inputs['input_ids'].to(self.device)
 | 
			
		||||
    attention_mask = model_inputs['attention_mask'].to(self.device)
 | 
			
		||||
    if pixel_values is not None:
 | 
			
		||||
        pixel_values = pixel_values.to(dtype=self.dtype, device=self.device)
 | 
			
		||||
    # ipex-llm changes end
 | 
			
		||||
 | 
			
		||||
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
 | 
			
		||||
    generation_config['eos_token_id'] = eos_token_id
 | 
			
		||||
    generation_output = self.generate(
 | 
			
		||||
        pixel_values=pixel_values,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
        attention_mask=attention_mask,
 | 
			
		||||
        **generation_config
 | 
			
		||||
    )
 | 
			
		||||
    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
 | 
			
		||||
    responses = [response.split(template.sep)[0].strip() for response in responses]
 | 
			
		||||
    return responses
 | 
			
		||||
		Loading…
	
		Reference in a new issue