support and optimize minicpm-v-2_6 (#11738)
This commit is contained in:
parent
e956e71fc1
commit
54cc9353db
2 changed files with 65 additions and 1 deletions
|
|
@ -740,13 +740,18 @@ def _optimize_pre(model, qtype=None):
|
||||||
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
||||||
model.apply(pre_process_attn_and_mlp)
|
model.apply(pre_process_attn_and_mlp)
|
||||||
if model.config.model_type == "internvl_chat":
|
if model.config.model_type == "internvl_chat":
|
||||||
_optimize_pre(model.language_model)
|
_optimize_pre(model.language_model, qtype=qtype)
|
||||||
if model.config.model_type == "gemma2":
|
if model.config.model_type == "gemma2":
|
||||||
from ipex_llm.transformers.models.gemma2 import merge_qkv
|
from ipex_llm.transformers.models.gemma2 import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
if model.config.model_type == "llama":
|
if model.config.model_type == "llama":
|
||||||
from ipex_llm.transformers.models.llama import merge_qkv
|
from ipex_llm.transformers.models.llama import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
|
if model.config.model_type == "minicpmv":
|
||||||
|
if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
|
||||||
|
model.llm.config.model_type = "qwen2"
|
||||||
|
_optimize_pre(model.llm, qtype=qtype)
|
||||||
|
model.llm.config.model_type = "minicpmv"
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
@ -1747,5 +1752,15 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.MiniCPMModel,
|
module.MiniCPMModel,
|
||||||
minicpm_model_forward)
|
minicpm_model_forward)
|
||||||
|
elif model.config.model_type == "minicpmv":
|
||||||
|
if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
|
||||||
|
model.llm.config.model_type = "qwen2"
|
||||||
|
_optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
|
||||||
|
model.llm.config.model_type = "minicpmv"
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
from ipex_llm.transformers.models.minicpmv import minicpmv_generate_wrapper
|
||||||
|
minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
|
||||||
|
model.generate = MethodType(minicpmv_generate, model)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
||||||
49
python/llm/src/ipex_llm/transformers/models/minicpmv.py
Normal file
49
python/llm/src/ipex_llm/transformers/models/minicpmv.py
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
def minicpmv_generate_wrapper(origin_generate):
    """Wrap MiniCPM-V's ``generate`` to neutralize repetition penalty.

    Returns a drop-in replacement for *origin_generate* with the same
    signature. If the caller passes a ``repetition_penalty`` keyword, it is
    overridden to ``1`` (i.e. disabled) before delegating; every other
    argument is forwarded unchanged.
    """
    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        tgt_sizes=None,
        image_bound=None,
        attention_mask=None,
        tokenizer=None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        stream=False,
        decode_text=False,
        **kwargs
    ):
        # Only rewrite the penalty when the caller actually supplied one;
        # an absent or explicit-None value is left untouched.
        penalty = kwargs.get("repetition_penalty")
        if penalty is not None:
            kwargs["repetition_penalty"] = 1

        # Delegate to the wrapped generate with every named argument
        # forwarded explicitly, plus any remaining keyword arguments.
        return origin_generate(
            self=self,
            input_ids=input_ids,
            pixel_values=pixel_values,
            tgt_sizes=tgt_sizes,
            image_bound=image_bound,
            attention_mask=attention_mask,
            tokenizer=tokenizer,
            vision_hidden_states=vision_hidden_states,
            return_vision_hidden_states=return_vision_hidden_states,
            stream=stream,
            decode_text=decode_text,
            **kwargs
        )

    return generate
|
||||||
Loading…
Reference in a new issue