diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 22985743..a6b25a8f 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -706,7 +706,7 @@ def _optimize_pre(model):
     if model.config.model_type == "phi":
         from ipex_llm.transformers.models.phi import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "phi3":
+    if model.config.model_type in ["phi3", "phi3_v"]:
         from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
         from ipex_llm.transformers.models.phi3 import split_mlp
@@ -1510,7 +1510,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.phi import model_forward
         convert_forward(model, module.PhiAttention, attention_forward)
         convert_forward(model, module.PhiModel, model_forward)
-    elif model.config.model_type == "phi3":
+    elif model.config.model_type in ["phi3", "phi3_v"]:
         # for phi-3
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1518,11 +1518,16 @@
         convert_forward(model, module.Phi3Attention, attention_forward)
         from ipex_llm.transformers.models.phi3 import mlp_forward
         convert_forward(model, module.Phi3MLP, mlp_forward)
-        from ipex_llm.transformers.models.phi3 import model_forward_wrapper
-        model_forward = model_forward_wrapper(module.Phi3Model.forward)
-        convert_forward(model, module.Phi3Model, model_forward)
         from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
         convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward)
+        if model.config.model_type == "phi3":
+            from ipex_llm.transformers.models.phi3 import phi3_model_forward_wrapper
+            model_forward = phi3_model_forward_wrapper(module.Phi3Model.forward)
+            convert_forward(model, module.Phi3Model, model_forward)
+        else:
+            from ipex_llm.transformers.models.phi3 import phi3v_model_forward_wrapper
+            model_forward = phi3v_model_forward_wrapper(module.Phi3VModel.forward)
+            convert_forward(model, module.Phi3VModel, model_forward)
     elif model.config.model_type == 'yuan':
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index adc61373..c1ea3ed2 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -215,7 +215,7 @@ def mlp_forward(
     )
 
 
-def model_forward_wrapper(origin_model_forward):
+def phi3_model_forward_wrapper(origin_model_forward):
     def model_forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -251,6 +251,46 @@ def model_forward_wrapper(origin_model_forward):
     return model_forward
 
 
+def phi3v_model_forward_wrapper(origin_model_forward):
+    def model_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
+        if use_cache:
+            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+            if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+                past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+        return origin_model_forward(
+            self=self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            pixel_values=pixel_values,
+            image_sizes=image_sizes,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+    return model_forward
+
+
 def phi3_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
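Note: the patch routes "phi3_v" checkpoints through a separate wrapper because
Phi3VModel.forward accepts extra pixel_values/image_sizes arguments that
Phi3Model.forward does not. Both wrappers follow the same wrap-and-rebind
pattern: capture the original forward in a closure, migrate any legacy KV
cache to the IPEX-LLM cache class, then delegate. The sketch below is a
minimal, self-contained illustration of that pattern only; TinyModel,
migrate_cache, and this convert_forward signature are hypothetical stand-ins,
not the actual ipex-llm implementations.

    from typing import Optional

    class TinyModel:
        """Hypothetical stand-in for a HF-style model whose forward we patch."""
        def forward(self, x, past_key_values=None, use_cache: Optional[bool] = None):
            return x, past_key_values

    def migrate_cache(past_key_values):
        # Stand-in for DynamicFp8Cache/DynamicNormalCache.from_legacy_cache():
        # normalize whatever cache object the caller handed us.
        return past_key_values if past_key_values is not None else []

    def model_forward_wrapper(origin_forward):
        # The closure keeps a reference to the original (unbound) forward,
        # so the wrapper can pre-process arguments and then delegate.
        def model_forward(self, x, past_key_values=None, use_cache=None):
            past_key_values = migrate_cache(past_key_values)
            return origin_forward(self, x, past_key_values=past_key_values,
                                  use_cache=use_cache)
        return model_forward

    def convert_forward(cls, new_forward):
        # convert_forward-style patching: rebind the class attribute so that
        # every instance of the class picks up the wrapped forward.
        cls.forward = new_forward

    convert_forward(TinyModel, model_forward_wrapper(TinyModel.forward))
    print(TinyModel().forward("tokens"))  # -> ('tokens', [])

Rebinding the class's forward (rather than each instance's) mirrors how the
patch converts every Phi3VModel in the loaded model with a single call.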