support internvl2-4b (#11718)
parent 7f241133da
commit f44b732aa8

2 changed files with 177 additions and 0 deletions
@@ -1256,6 +1256,18 @@ def _optimize_post(model, lightweight_bmm=False):
         )
         convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
         model.chat = MethodType(internlm_xcomposser2_chat, model)
+    elif model.config.model_type == "internvl_chat":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.internvl import internvl_chat
+        from ipex_llm.transformers.models.internvl import internvl_batch_chat
+        model.get_conv_template = module.get_conv_template
+        model.chat = MethodType(internvl_chat, model)
+        model.batch_chat = MethodType(internvl_batch_chat, model)
+        if model.vision_model.__class__.__name__ == "InternVisionModel":
+            from ipex_llm.transformers.models.internvl import _get_pos_embed
+            vision_embedding = model.vision_model.embeddings
+            vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
     elif model.config.model_type == "qwen":
         if hasattr(model.config, "visual"):
             # for Qwen-VL-Chat
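For context, this branch monkey-patches an already-loaded `internvl_chat` checkpoint: it binds `internvl_chat`/`internvl_batch_chat` from the new module (added below) as the model's `chat`/`batch_chat` methods, and swaps in a CPU-interpolating `_get_pos_embed` on the vision tower's embeddings. A minimal usage sketch follows; it is not part of this commit, and it assumes ipex-llm's standard low-bit `AutoModelForCausalLM` loader, an Intel XPU device, and pixel values prepared with the `load_image` helper from the InternVL2 model card. `batch_chat` follows the same pattern with a list of questions plus `num_patches_list`.

    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM

    path = "OpenGVLab/InternVL2-4B"
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    # The low-bit loader runs ipex-llm's model optimizations (including the
    # _optimize_post branch above, which installs chat/batch_chat) at load time.
    model = AutoModelForCausalLM.from_pretrained(path, load_in_4bit=True,
                                                 trust_remote_code=True).eval()
    model = model.to('xpu')

    # Placeholder pixel values; in practice use the model card's load_image().
    pixel_values = torch.rand(1, 3, 448, 448)
    response = model.chat(tokenizer, pixel_values,
                          "<image>\nDescribe the image.",
                          generation_config={"max_new_tokens": 64})
    print(response)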
python/llm/src/ipex_llm/transformers/models/internvl.py (new file, 165 lines)

@@ -0,0 +1,165 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Some parts of this file is adapted from
+# https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
+# which is licensed under MIT:
+#
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+#
+
+import torch
+from ipex_llm.utils.common.log4Error import invalidInputError
+
+
+def _get_pos_embed(self, pos_embed, H, W):
+    target_dtype = pos_embed.dtype
+    device = pos_embed.device
+    pos_embed = pos_embed.float().reshape(
+        1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1
+    ).permute(0, 3, 1, 2)
+    # ipex-llm change start: call interpolate on CPU to fix bug
+    pos_embed = torch.nn.functional.interpolate(
+        pos_embed.to('cpu'), size=(H, W), mode='bicubic', align_corners=False
+    ).reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype).to(device)
+    # ipex-llm changes end
+    return pos_embed
+
+
+def internvl_chat(self, tokenizer, pixel_values, question, generation_config,
+                  history=None, return_history=False, num_patches_list=None,
+                  IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                  IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
+
+    if history is None and pixel_values is not None and '<image>' not in question:
+        question = '<image>\n' + question
+
+    if num_patches_list is None:
+        num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+    invalidInputError(pixel_values is None or len(pixel_values) == sum(num_patches_list),
+                      "wrong num_patches_list length")
+
+    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+    self.img_context_token_id = img_context_token_id
+
+    template = self.get_conv_template(self.template)
+    template.system_message = self.system_message
+    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+
+    history = [] if history is None else history
+    for (old_question, old_answer) in history:
+        template.append_message(template.roles[0], old_question)
+        template.append_message(template.roles[1], old_answer)
+    template.append_message(template.roles[0], question)
+    template.append_message(template.roles[1], None)
+    query = template.get_prompt()
+
+    if verbose and pixel_values is not None:
+        image_bs = pixel_values.shape[0]
+        print(f'dynamic ViT batch size: {image_bs}')
+
+    for num_patches in num_patches_list:
+        image_tokens = (IMG_START_TOKEN
+                        + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                        + IMG_END_TOKEN)
+        query = query.replace('<image>', image_tokens, 1)
+    model_inputs = tokenizer(query, return_tensors='pt')
+
+    # ipex-llm changes start: move input_ids and attention_mask to xpu
+    input_ids = model_inputs['input_ids'].to(self.device)
+    attention_mask = model_inputs['attention_mask'].to(self.device)
+    if pixel_values is not None:
+        pixel_values = pixel_values.to(dtype=self.dtype, device=self.device)
+    # ipex-llm changes end
+
+    generation_config['eos_token_id'] = eos_token_id
+    generation_output = self.generate(
+        pixel_values=pixel_values,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        **generation_config
+    )
+    response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+    response = response.split(template.sep)[0].strip()
+    history.append((question, response))
+    if return_history:
+        return response, history
+    else:
+        query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+        query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
+        if verbose:
+            print(query_to_print, response)
+        return response
+
+
+def internvl_batch_chat(self, tokenizer, pixel_values, questions, generation_config,
+                        num_patches_list=None, history=None, return_history=False,
+                        IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                        IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
+    invalidInputError(history is None and not return_history,
+                      'Now multi-turn chat is not supported in batch_chat.')
+
+    if image_counts is not None:
+        num_patches_list = image_counts
+        print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
+
+    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+    self.img_context_token_id = img_context_token_id
+
+    if verbose and pixel_values is not None:
+        image_bs = pixel_values.shape[0]
+        print(f'dynamic ViT batch size: {image_bs}')
+
+    queries = []
+    for idx, num_patches in enumerate(num_patches_list):
+        question = questions[idx]
+        if pixel_values is not None and '<image>' not in question:
+            question = '<image>\n' + question
+        template = self.get_conv_template(self.template)
+        template.append_message(template.roles[0], question)
+        template.append_message(template.roles[1], None)
+        query = template.get_prompt()
+
+        image_tokens = (IMG_START_TOKEN
+                        + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                        + IMG_END_TOKEN)
+        query = query.replace('<image>', image_tokens, 1)
+        queries.append(query)
+
+    tokenizer.padding_side = 'left'
+    model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+
+    # ipex-llm changes start: move input_ids and attention_mask to xpu
+    input_ids = model_inputs['input_ids'].to(self.device)
+    attention_mask = model_inputs['attention_mask'].to(self.device)
+    if pixel_values is not None:
+        pixel_values = pixel_values.to(dtype=self.dtype, device=self.device)
+    # ipex-llm changes end
+
+    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+    generation_config['eos_token_id'] = eos_token_id
+    generation_output = self.generate(
+        pixel_values=pixel_values,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        **generation_config
+    )
+    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
+    responses = [response.split(template.sep)[0].strip() for response in responses]
+    return responses
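The `_get_pos_embed` replacement above works around an `interpolate` bug on the accelerator (per the in-code comment) by hopping to the CPU for the bicubic resize and then restoring the original dtype and device. A standalone sketch of that pattern, purely illustrative with arbitrary shapes:

    import torch
    import torch.nn.functional as F

    def interpolate_on_cpu(x, size):
        # Upsample on CPU in float32, then restore the caller's dtype/device.
        device, dtype = x.device, x.dtype
        y = F.interpolate(x.float().cpu(), size=size,
                          mode='bicubic', align_corners=False)
        return y.to(dtype).to(device)

    x = torch.randn(1, 8, 16, 16)                  # (N, C, H, W) feature map
    print(interpolate_on_cpu(x, (32, 32)).shape)   # torch.Size([1, 8, 32, 32])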