LLM: fix qwen-vl interpolation gpu abnormal results. (#10457)
* fix qwen-vl interpolation gpu abnormal results.
* fix style.
* update qwen-vl gpu example.
* fix comment and update example.
* fix style.
This commit is contained in:
parent e9055c32f9
commit 463a86cd5d

4 changed files with 82 additions and 21 deletions

@@ -43,18 +43,9 @@ if __name__ == '__main__':
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
                                                  trust_remote_code=True,
-                                                 modules_to_not_convert=['c_fc', 'out_proj'])
+                                                 modules_to_not_convert=['c_fc', 'out_proj'],
+                                                 torch_dtype=torch.float32)
     model = model.to('xpu')
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
 
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)

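Before this commit, both example scripts (this hunk and the next) worked around intel-extension-for-pytorch issue #454 by registering a forward hook on model.transformer.visual.ln_pre and a forward pre-hook on model.transformer.visual.transformer, bouncing the activation to CPU and back to XPU so that, per the original comment, the interpolation ran on CPU. The hooks are dropped here because the workaround now lives inside bigdl-llm itself (see the ggml_convert_low_bit and qwen_vl hunks below); the GPU example instead passes torch_dtype=torch.float32 when loading. Below is a minimal, self-contained sketch of that hook mechanism on toy modules; the layer names are stand-ins and the target device is left as "cpu" so the snippet runs anywhere (the real example used "xpu").

    import torch
    import torch.nn as nn

    # Toy stand-ins for model.transformer.visual.ln_pre and the visual transformer blocks.
    ln_pre = nn.LayerNorm(8)
    blocks = nn.Linear(8, 8)
    device = "cpu"  # the real example used "xpu"

    def to_cpu(module, input, output):
        # forward hook: replace the module's output with a CPU copy
        return output.to("cpu")

    def to_device(module, input):
        # forward pre-hook: move the incoming activation back to the target device
        return (input[0].to(device),)

    ln_pre.register_forward_hook(to_cpu)
    blocks.register_forward_pre_hook(to_device)

    x = torch.randn(2, 8, device=device)
    print(blocks(ln_pre(x)).shape)  # torch.Size([2, 8])
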
@@ -47,16 +47,6 @@ if __name__ == '__main__':
                            low_bit='sym_int4',
                            modules_to_not_convert=['c_fc', 'out_proj'])
     model = model.to('xpu')
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
 
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)

@@ -689,6 +689,23 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
+
+    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
+        # for Qwen-VL-Chat
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        visual_module_name = model.transformer.visual.__class__.__module__
+        visual_module = importlib.import_module(visual_module_name)
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_vision_transformer_forward
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_resampler_forward
+        convert_forward(model,
+                        visual_module.VisionTransformer,
+                        qwen_vl_vision_transformer_forward
+                        )
+        convert_forward(model,
+                        visual_module.Resampler,
+                        qwen_vl_resampler_forward
+                        )
     return model
 
 

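This is the library-side fix: inside ggml_convert_low_bit, when the loaded model is a Qwen model that also carries a visual config (Qwen-VL-Chat), the code recovers the trust_remote_code module that defines the visual tower from the instance's __class__.__module__ and swaps the forward methods of its VisionTransformer and Resampler classes for the patched versions in bigdl.llm.transformers.models.qwen_vl. The importlib indirection is needed because, with trust_remote_code, those classes live in a dynamically loaded module rather than a fixed import path. A minimal sketch of what a convert_forward-style helper does is below; the actual implementation in bigdl-llm may differ in detail.

    from types import MethodType

    def convert_forward_sketch(model, target_class, new_forward):
        # Rebind `forward` on every submodule that is an instance of target_class.
        for module in model.modules():
            if isinstance(module, target_class):
                module.forward = MethodType(new_forward, module)
        return model
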
@@ -47,6 +47,27 @@ def apply_rotary_pos_emb(t, freqs):
     return torch.cat((t_, t_pass_), dim=-1).type_as(t)
 
 
+def get_abs_pos(abs_pos, tgt_size):
+    # abs_pos: L, C
+    # tgt_size: M
+    # return: M, C
+    src_size = int(math.sqrt(abs_pos.size(0)))
+    tgt_size = int(math.sqrt(tgt_size))
+    dtype = abs_pos.dtype
+
+    if src_size != tgt_size:
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        return F.interpolate(
+            abs_pos.to("cpu").float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
+            size=(tgt_size, tgt_size),
+            mode="bicubic",
+            align_corners=False,
+        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype).to(abs_pos.device)
+    else:
+        return abs_pos
+
+
 def qwen_attention_forward_vl(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],

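get_abs_pos resizes the learned absolute positional embedding from its source grid to the target grid with bicubic interpolation, mirroring the original Qwen-VL helper, except that the tensor is moved to CPU for F.interpolate and only afterwards cast and moved back, sidestepping the abnormal XPU results from the IPEX issue referenced above. A quick shape check of that CPU path with made-up sizes (16x16 source grid, 24x24 target grid, embedding dim 32):

    import math
    import torch
    import torch.nn.functional as F

    abs_pos = torch.randn(16 * 16, 32)   # L, C with L a perfect square
    tgt_size = 24 * 24                   # M, also a perfect square

    src = int(math.sqrt(abs_pos.size(0)))
    tgt = int(math.sqrt(tgt_size))
    out = F.interpolate(
        abs_pos.to("cpu").float().reshape(1, src, src, -1).permute(0, 3, 1, 2),
        size=(tgt, tgt),
        mode="bicubic",
        align_corners=False,
    ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=abs_pos.dtype).to(abs_pos.device)
    print(out.shape)  # torch.Size([576, 32])
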
@@ -151,3 +172,45 @@ def qwen_attention_forward_vl(
         outputs += (attn_weight,)
 
     return outputs
+
+
+def qwen_vl_resampler_forward(self, x, attn_mask=None):
+
+    pos_embed = get_abs_pos(self.pos_embed, x.size(1))
+
+    x = self.kv_proj(x)
+    x = self.ln_kv(x).permute(1, 0, 2)
+
+    N = x.shape[1]
+    q = self.ln_q(self.query)
+    out = self.attn(
+        self._repeat(q, N) + self.pos_embed.unsqueeze(1),
+        x + pos_embed.unsqueeze(1),
+        x,
+        attn_mask=attn_mask)[0]
+    return out.permute(1, 0, 2)
+
+
+def qwen_vl_vision_transformer_forward(self, x: torch.Tensor):
+    x = x.to(
+        dtype=self.transformer.get_cast_dtype(),
+        device=self.transformer.get_cast_device(),
+    )
+    # to patches
+    x = self.conv1(x)  # shape = [*, width, grid, grid]
+    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+    x = x + get_abs_pos(self.positional_embedding, x.size(1))
+
+    x = self.ln_pre(x)
+
+    x = x.permute(1, 0, 2)  # NLD -> LND
+    x = self.transformer(x)
+    x = x.permute(1, 0, 2)  # LND -> NLD
+
+    x = self.attn_pool(x)
+    x = self.ln_post(x)
+    x = x @ self.proj
+
+    return x

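With these patched forwards installed during conversion, the user-facing change is only the simplified example shown in the first hunk: load as usual, pass torch_dtype=torch.float32, and skip the device-bouncing hooks. Roughly, based on the updated hunk (the import path and placeholder model path are assumed here, and prompt/image handling is omitted):

    import torch
    from transformers import GenerationConfig
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "Qwen/Qwen-VL-Chat"  # placeholder; the example takes this as an argument
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True,
                                                 modules_to_not_convert=['c_fc', 'out_proj'],
                                                 torch_dtype=torch.float32)
    model = model.to('xpu')

    # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
    model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)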