diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 22989304..994c822b 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1066,7 +1066,7 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.baichuan_m1 import pre_register_inv_freq
         model.apply(pre_register_inv_freq)
     elif model.config.model_type == "multi_modality":
-        pass
+        _optimize_pre(model.language_model)
 
     return model
 
@@ -2012,8 +2012,10 @@ def _optimize_post(model):
         # vision
         vpm_modeling_module_name = model.vision_model.vision_tower.__class__.__module__
         vpm_module = importlib.import_module(vpm_modeling_module_name)
-        from ipex_llm.transformers.models.janus import vision_attention_forward
         convert_forward(model.vision_model, vpm_module.Attention, vision_attention_forward)
+        # llm
+        _optimize_post(model.language_model)
+
     return model
 
 
diff --git a/python/llm/src/ipex_llm/transformers/models/janus.py b/python/llm/src/ipex_llm/transformers/models/janus.py
new file mode 100644
index 00000000..e9d7ba0b
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/models/janus.py
@@ -0,0 +1,49 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is adapted from
+# https://github.com/deepseek-ai/Janus/blob/main/janus/models/siglip_vit.py
+
+import torch
+
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+
+
+def vision_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
+    B, N, C = x.shape
+    qkv = (
+        self.qkv(x)
+        .reshape(B, N, 3, self.num_heads, self.head_dim)
+        .permute(2, 0, 3, 1, 4)
+    )
+    q, k, v = qkv.unbind(0)
+    q, k = self.q_norm(q), self.k_norm(k)
+
+    if self.fused_attn:
+        # ipex-llm opt: sdpa
+        x = scaled_dot_product_attention(
+            q, k.contiguous(), v.contiguous(), None, False
+        )
+    else:
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v
+
+    x = x.transpose(1, 2).reshape(B, N, C)
+    x = self.proj(x)
+    x = self.proj_drop(x)
+    return x
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 0dbea2e9..0e3e897c 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -86,7 +86,7 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
         return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
     elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
         return os.environ["IPEX_LLM_LOW_MEM"] == "1"
-    elif linear.qtype in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+    elif linear.weight.dtype != torch.uint8:  # unquantized
         return False
     else:
         device_name = get_xpu_device_name(x.device)
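
Note on the janus.py hunk: vision_attention_forward keeps the eager softmax path but routes the common case through a fused scaled-dot-product-attention kernel via ipex_llm.transformers.models.common.scaled_dot_product_attention. The sketch below is only an illustration of the equivalence that rewrite relies on, using stock PyTorch (torch.nn.functional.scaled_dot_product_attention); the ipex-llm wrapper's exact behaviour is not shown here and is assumed to match this no-mask, non-causal case, and the tensor shapes are made-up SigLIP-style values, not taken from the diff.

import torch
import torch.nn.functional as F

# Illustrative shapes only: batch 1, 16 heads, 576 patch tokens, head_dim 64
# (roughly what a SigLIP-style vision tower would produce).
q = torch.randn(1, 16, 576, 64)
k = torch.randn(1, 16, 576, 64)
v = torch.randn(1, 16, 576, 64)

# Eager path, as in the `else` branch of vision_attention_forward.
scale = q.shape[-1] ** -0.5
attn = (q * scale) @ k.transpose(-2, -1)
eager = attn.softmax(dim=-1) @ v

# Fused path: no attention mask, no causal masking, matching the
# (None, False) arguments passed to the ipex-llm helper in the diff.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)

print(torch.allclose(eager, fused, atol=1e-5))  # True, up to numerical error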
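
Note on the utils.py hunk: instead of matching only the fp16/bf16 ggml qtypes, use_quantize_kv_cache now treats any linear whose stored weight is not uint8 as unquantized, since ipex-llm's low-bit linear layers pack quantized weights into uint8 buffers. The snippet below is a hypothetical stand-alone helper (is_unquantized is not part of ipex-llm) that just mirrors the new predicate under that assumption.

import torch

def is_unquantized(linear: torch.nn.Module) -> bool:
    # Low-bit layers keep packed quantized weights in uint8 storage;
    # any other dtype means the layer was left in fp32/fp16/bf16.
    return linear.weight.dtype != torch.uint8

lin = torch.nn.Linear(64, 64).half()   # plain fp16 layer, never quantized
print(is_unquantized(lin))             # True -> use_quantize_kv_cache returns False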