LLM: fix qwen-vl interpolation gpu abnormal results. (#10457)
* fix qwen-vl interpolation gpu abnormal results.
* fix style.
* update qwen-vl gpu example.
* fix comment and update example.
* fix style.
This commit is contained in:
parent e9055c32f9
commit 463a86cd5d

4 changed files with 82 additions and 21 deletions
In the Qwen-VL GPU example, the manual CPU-interpolation hooks are removed and the modules excluded from 4-bit conversion are kept in float32 instead:

@@ -43,18 +43,9 @@ if __name__ == '__main__':
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
                                                  trust_remote_code=True,
-                                                 modules_to_not_convert=['c_fc', 'out_proj'])
+                                                 modules_to_not_convert=['c_fc', 'out_proj'],
+                                                 torch_dtype=torch.float32)
     model = model.to('xpu')
 
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
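For context, the deleted lines were an instance of PyTorch's generic forward-hook device-routing pattern. Below is a minimal, self-contained sketch of that pattern on a toy module; the Toy class and the accel_device fallback are illustrative stand-ins, not Qwen-VL's real structure (the real script hooked model.transformer.visual.ln_pre and model.transformer.visual.transformer and used "xpu"):

import torch
import torch.nn as nn

# On a machine with intel_extension_for_pytorch, accel_device would be
# "xpu"; we fall back to "cpu" so this sketch runs anywhere.
accel_device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

class Toy(nn.Module):
    # Stand-in for the visual tower: `pre` plays ln_pre, `body` plays
    # the inner transformer block.
    def __init__(self):
        super().__init__()
        self.pre = nn.LayerNorm(8)
        self.body = nn.Linear(8, 8)

    def forward(self, x):
        return self.body(self.pre(x))

model = Toy().to(accel_device)

# A forward hook that returns a value replaces the module's output,
# so whatever runs between `pre` and `body` sees CPU tensors.
def to_cpu(module, input, output):
    return output.to("cpu")

# A forward pre-hook that returns a tuple replaces the module's input,
# shipping it back to the accelerator before `body` executes.
def to_accel(module, input):
    return (input[0].to(accel_device),)

model.pre.register_forward_hook(to_cpu)
model.body.register_forward_pre_hook(to_accel)

print(model(torch.randn(2, 8, device=accel_device)).shape)  # torch.Size([2, 8])

With this commit the same routing happens inside the library (see the get_abs_pos hunk below), so user scripts no longer need the hooks.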
The same hooks are removed from the second Qwen-VL GPU example (sym_int4 low-bit loading):

@@ -47,16 +47,6 @@ if __name__ == '__main__':
                                               low_bit='sym_int4',
                                               modules_to_not_convert=['c_fc', 'out_proj'])
     model = model.to('xpu')
 
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
In ggml_convert_low_bit, the library now patches Qwen-VL's visual modules itself, so user scripts no longer carry the workaround:

@@ -689,6 +689,23 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
 
+    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
+        # for Qwen-VL-Chat
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        visual_module_name = model.transformer.visual.__class__.__module__
+        visual_module = importlib.import_module(visual_module_name)
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_vision_transformer_forward
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_resampler_forward
+        convert_forward(model,
+                        visual_module.VisionTransformer,
+                        qwen_vl_vision_transformer_forward
+                        )
+        convert_forward(model,
+                        visual_module.Resampler,
+                        qwen_vl_resampler_forward
+                        )
+
     return model
 
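convert_forward is the library's existing helper for swapping in a patched forward; the hunk reuses it for the two visual classes. As a rough illustration of the mechanism (a hypothetical sketch, not BigDL's actual implementation), such a helper can rebind the forward of every matching submodule:

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    """Bind new_forward as the forward of every target_cls instance."""
    for module in model.modules():
        if isinstance(module, target_cls):
            # MethodType makes `self` resolve to this module instance.
            module.forward = types.MethodType(new_forward, module)

Resolving visual_module through model.transformer.visual.__class__.__module__ matters because Qwen-VL is loaded with trust_remote_code: its VisionTransformer and Resampler classes live in a dynamically downloaded module rather than in transformers itself, so they cannot be imported by a fixed path.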
In bigdl.llm.transformers.models.qwen_vl, a new get_abs_pos helper stages the positional-embedding interpolation on CPU:

@@ -47,6 +47,27 @@ def apply_rotary_pos_emb(t, freqs):
     return torch.cat((t_, t_pass_), dim=-1).type_as(t)
 
 
+def get_abs_pos(abs_pos, tgt_size):
+    # abs_pos: L, C
+    # tgt_size: M
+    # return: M, C
+    src_size = int(math.sqrt(abs_pos.size(0)))
+    tgt_size = int(math.sqrt(tgt_size))
+    dtype = abs_pos.dtype
+
+    if src_size != tgt_size:
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        return F.interpolate(
+            abs_pos.to("cpu").float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
+            size=(tgt_size, tgt_size),
+            mode="bicubic",
+            align_corners=False,
+        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype).to(abs_pos.device)
+    else:
+        return abs_pos
+
+
 def qwen_attention_forward_vl(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],
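The helper's numerics are easy to check in isolation. A self-contained run of the same CPU-staged bicubic resize with toy sizes (a 16x16 positional grid widened to 24x24, 64 channels; all values illustrative):

import math
import torch
import torch.nn.functional as F

abs_pos = torch.randn(16 * 16, 64)     # L = 256 positions, C = 64 channels
tgt_len = 24 * 24                      # target sequence length M = 576

src = int(math.sqrt(abs_pos.size(0)))  # 16
tgt = int(math.sqrt(tgt_len))          # 24

# (L, C) -> (1, C, src, src) -> bicubic resize on CPU -> (M, C)
out = F.interpolate(
    abs_pos.to("cpu").float().reshape(1, src, src, -1).permute(0, 3, 1, 2),
    size=(tgt, tgt), mode="bicubic", align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2)

print(out.shape)  # torch.Size([576, 64]) -- M x C, matching the comment above

Before this commit the interpolation ran on the XPU tensor directly, which triggered the abnormal results the title refers to (IPEX issue 454); staging it on CPU and casting back via .to(dtype=dtype).to(abs_pos.device) sidesteps the bug.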
The two replacement forwards route every positional-embedding resize through get_abs_pos:

@@ -151,3 +172,45 @@ def qwen_attention_forward_vl(
         outputs += (attn_weight,)
 
     return outputs
+
+
+def qwen_vl_resampler_forward(self, x, attn_mask=None):
+
+    pos_embed = get_abs_pos(self.pos_embed, x.size(1))
+
+    x = self.kv_proj(x)
+    x = self.ln_kv(x).permute(1, 0, 2)
+
+    N = x.shape[1]
+    q = self.ln_q(self.query)
+    out = self.attn(
+        self._repeat(q, N) + self.pos_embed.unsqueeze(1),
+        x + pos_embed.unsqueeze(1),
+        x,
+        attn_mask=attn_mask)[0]
+    return out.permute(1, 0, 2)
+
+
+def qwen_vl_vision_transformer_forward(self, x: torch.Tensor):
+    x = x.to(
+        dtype=self.transformer.get_cast_dtype(),
+        device=self.transformer.get_cast_device(),
+    )
+    # to patches
+    x = self.conv1(x)  # shape = [*, width, grid, grid]
+    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+    x = x + get_abs_pos(self.positional_embedding, x.size(1))
+
+    x = self.ln_pre(x)
+
+    x = x.permute(1, 0, 2)  # NLD -> LND
+    x = self.transformer(x)
+    x = x.permute(1, 0, 2)  # LND -> NLD
+
+    x = self.attn_pool(x)
+    x = self.ln_post(x)
+    x = x @ self.proj
+
+    return x
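Both patched forwards permute between batch-first (NLD) and sequence-first (LND) layouts because torch.nn.MultiheadAttention expects sequence-first tensors by default (assuming, as in upstream Qwen-VL, that the resampler's self.attn is an nn.MultiheadAttention). A toy sketch of that cross-attention shape flow, with illustrative sizes rather than Qwen-VL's real dimensions:

import torch
import torch.nn as nn

N, L, E, Q = 2, 576, 64, 256                 # batch, kv length, width, queries
attn = nn.MultiheadAttention(E, num_heads=8) # batch_first=False: wants (L, N, E)

x = torch.randn(N, L, E).permute(1, 0, 2)    # NLD -> LND, as in ln_kv(...).permute
q = torch.randn(Q, 1, E).repeat(1, N, 1)     # learned queries repeated per batch,
                                             # mirroring self._repeat(q, N)

out, _ = attn(q, x, x)                       # queries cross-attend to the kv stream
print(out.permute(1, 0, 2).shape)            # LND -> NLD: torch.Size([2, 256, 64])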