LLM: fix qwen-vl interpolation gpu abnormal results. (#10457)

* fix qwen-vl interpolation gpu abnormal results.

* fix style.

* update qwen-vl gpu example.

* fix comment and update example.

* fix style.
This commit is contained in:
Cengguang Zhang 2024-03-19 16:59:39 +08:00 committed by GitHub
parent e9055c32f9
commit 463a86cd5d
4 changed files with 82 additions and 21 deletions

View file

@ -43,18 +43,9 @@ if __name__ == '__main__':
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_4bit=True,
trust_remote_code=True,
modules_to_not_convert=['c_fc', 'out_proj'])
modules_to_not_convert=['c_fc', 'out_proj'],
torch_dtype=torch.float32)
model = model.to('xpu')
# Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
# currently put interpolation execution into cpu
def to_cpu(module, input, output):
return output.to("cpu")
def to_xpu(module, input):
return (input[0].to("xpu"),)
model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
# Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)

View file

@ -47,16 +47,6 @@ if __name__ == '__main__':
low_bit='sym_int4',
modules_to_not_convert=['c_fc', 'out_proj'])
model = model.to('xpu')
# Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
# currently put interpolation execution into cpu
def to_cpu(module, input, output):
return output.to("cpu")
def to_xpu(module, input):
return (input[0].to("xpu"),)
model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
# Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)

View file

@ -689,6 +689,23 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
if optimize_model:
model = _optimize_post(model, lightweight_bmm)
if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
# for Qwen-VL-Chat
# Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
# currently put interpolation execution into cpu
visual_module_name = model.transformer.visual.__class__.__module__
visual_module = importlib.import_module(visual_module_name)
from bigdl.llm.transformers.models.qwen_vl import qwen_vl_vision_transformer_forward
from bigdl.llm.transformers.models.qwen_vl import qwen_vl_resampler_forward
convert_forward(model,
visual_module.VisionTransformer,
qwen_vl_vision_transformer_forward
)
convert_forward(model,
visual_module.Resampler,
qwen_vl_resampler_forward
)
return model

View file

@ -47,6 +47,27 @@ def apply_rotary_pos_emb(t, freqs):
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
# Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
# currently put interpolation execution into cpu
return F.interpolate(
abs_pos.to("cpu").float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype).to(abs_pos.device)
else:
return abs_pos
def qwen_attention_forward_vl(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
@ -151,3 +172,45 @@ def qwen_attention_forward_vl(
outputs += (attn_weight,)
return outputs
def qwen_vl_resampler_forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
return out.permute(1, 0, 2)
def qwen_vl_vision_transformer_forward(self, x: torch.Tensor):
x = x.to(
dtype=self.transformer.get_cast_dtype(),
device=self.transformer.get_cast_device(),
)
# to patches
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = x + get_abs_pos(self.positional_embedding, x.size(1))
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.attn_pool(x)
x = self.ln_post(x)
x = x @ self.proj
return x