diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/qwen-vl/chat.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/qwen-vl/chat.py
index 9df8cf46..4701eb5a 100644
--- a/python/llm/example/GPU/HF-Transformers-AutoModels/Model/qwen-vl/chat.py
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Model/qwen-vl/chat.py
@@ -43,18 +43,9 @@ if __name__ == '__main__':
     model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                                  trust_remote_code=True,
-                                                 modules_to_not_convert=['c_fc', 'out_proj'])
+                                                 modules_to_not_convert=['c_fc', 'out_proj'],
+                                                 torch_dtype=torch.float32)
     model = model.to('xpu')
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
 
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
diff --git a/python/llm/example/GPU/PyTorch-Models/Model/qwen-vl/chat.py b/python/llm/example/GPU/PyTorch-Models/Model/qwen-vl/chat.py
index df869916..29adf173 100644
--- a/python/llm/example/GPU/PyTorch-Models/Model/qwen-vl/chat.py
+++ b/python/llm/example/GPU/PyTorch-Models/Model/qwen-vl/chat.py
@@ -47,16 +47,6 @@ if __name__ == '__main__':
                            low_bit='sym_int4',
                            modules_to_not_convert=['c_fc', 'out_proj'])
     model = model.to('xpu')
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
 
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 83b96f88..7ea3f4e2 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -689,6 +689,23 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
+
+    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
+        # for Qwen-VL-Chat
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        visual_module_name = model.transformer.visual.__class__.__module__
+        visual_module = importlib.import_module(visual_module_name)
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_vision_transformer_forward
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_resampler_forward
+        convert_forward(model,
+                        visual_module.VisionTransformer,
+                        qwen_vl_vision_transformer_forward
+                        )
+        convert_forward(model,
+                        visual_module.Resampler,
+                        qwen_vl_resampler_forward
+                        )
     return model
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py b/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py
index 82fcf90b..4094ae14 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen_vl.py
@@ -47,6 +47,27 @@ def apply_rotary_pos_emb(t, freqs):
     return torch.cat((t_, t_pass_), dim=-1).type_as(t)
 
 
+def get_abs_pos(abs_pos, tgt_size):
+    # abs_pos: L, C
+    # tgt_size: M
+    # return: M, C
+    src_size = int(math.sqrt(abs_pos.size(0)))
+    tgt_size = int(math.sqrt(tgt_size))
+    dtype = abs_pos.dtype
+
+    if src_size != tgt_size:
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        return F.interpolate(
+            abs_pos.to("cpu").float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
+            size=(tgt_size, tgt_size),
+            mode="bicubic",
+            align_corners=False,
+        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype).to(abs_pos.device)
+    else:
+        return abs_pos
+
+
 def qwen_attention_forward_vl(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -151,3 +172,45 @@ def qwen_attention_forward_vl(
         outputs += (attn_weight,)
 
     return outputs
+
+
+def qwen_vl_resampler_forward(self, x, attn_mask=None):
+
+    pos_embed = get_abs_pos(self.pos_embed, x.size(1))
+
+    x = self.kv_proj(x)
+    x = self.ln_kv(x).permute(1, 0, 2)
+
+    N = x.shape[1]
+    q = self.ln_q(self.query)
+    out = self.attn(
+        self._repeat(q, N) + self.pos_embed.unsqueeze(1),
+        x + pos_embed.unsqueeze(1),
+        x,
+        attn_mask=attn_mask)[0]
+    return out.permute(1, 0, 2)
+
+
+def qwen_vl_vision_transformer_forward(self, x: torch.Tensor):
+    x = x.to(
+        dtype=self.transformer.get_cast_dtype(),
+        device=self.transformer.get_cast_device(),
+    )
+    # to patches
+    x = self.conv1(x)  # shape = [*, width, grid, grid]
+    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+    x = x + get_abs_pos(self.positional_embedding, x.size(1))
+
+    x = self.ln_pre(x)
+
+    x = x.permute(1, 0, 2)  # NLD -> LND
+    x = self.transformer(x)
+    x = x.permute(1, 0, 2)  # LND -> NLD
+
+    x = self.attn_pool(x)
+    x = self.ln_post(x)
+    x = x @ self.proj
+
+    return x
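
For reference, below is a minimal standalone sketch of the CPU-fallback pattern that get_abs_pos above relies on: run the bicubic interpolation of the absolute position embeddings on CPU (as a workaround for intel-extension-for-pytorch issue #454) and move the result back to the embeddings' original device and dtype. The helper name interpolate_abs_pos and the tensor sizes in the usage lines are illustrative assumptions, not part of the patch.

    # Sketch only: mirrors the interpolation workaround in get_abs_pos.
    import math
    import torch
    import torch.nn.functional as F

    def interpolate_abs_pos(abs_pos: torch.Tensor, tgt_len: int) -> torch.Tensor:
        # abs_pos: (L, C) learned position embeddings, possibly on an XPU device
        src_size = int(math.sqrt(abs_pos.size(0)))
        tgt_size = int(math.sqrt(tgt_len))
        if src_size == tgt_size:
            return abs_pos
        dtype, device = abs_pos.dtype, abs_pos.device
        # Run bicubic interpolation on CPU, then restore the original dtype/device.
        resized = F.interpolate(
            abs_pos.to("cpu").float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        )
        return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype).to(device)

    if __name__ == "__main__":
        pos = torch.randn(256, 1664)            # hypothetical 16x16 grid of 1664-dim embeddings
        out = interpolate_abs_pos(pos, 1024)    # resize to a 32x32 grid
        print(out.shape)                        # torch.Size([1024, 1664])

Because this fallback now lives in the patched forwards in qwen_vl.py and is wired up during ggml_convert_low_bit, the example scripts no longer need the ln_pre/transformer forward hooks that previously shuttled activations between CPU and XPU.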