LLM: fix qwen-vl interpolation gpu abnormal results. (#10457)
* fix qwen-vl interpolation gpu abnormal results.
* fix style.
* update qwen-vl gpu example.
* fix comment and update example.
* fix style.

parent e9055c32f9
commit 463a86cd5d
4 changed files with 82 additions and 21 deletions
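
The root cause, per the commit title and the comments below, is that running the interpolation of Qwen-VL's absolute position embeddings on the XPU device produces abnormal results (see https://github.com/intel/intel-extension-for-pytorch/issues/454). The patch moves that interpolation onto the CPU inside bigdl-llm and removes the hook-based workaround from the GPU examples. A minimal sketch of the CPU round-trip pattern, assuming a (src_size * src_size, C) embedding grid; the helper name interpolate_on_cpu is illustrative, not part of the patch:

    import torch
    import torch.nn.functional as F

    def interpolate_on_cpu(pos_embed: torch.Tensor, src_size: int, tgt_size: int) -> torch.Tensor:
        # pos_embed: (src_size * src_size, C), possibly living on an XPU device.
        # Run bicubic interpolation on CPU, then restore the original dtype and device.
        dtype, device = pos_embed.dtype, pos_embed.device
        grid = pos_embed.float().cpu().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(tgt_size, tgt_size), mode="bicubic", align_corners=False)
        return grid.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype, device=device)
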
@@ -43,18 +43,9 @@ if __name__ == '__main__':
     model = AutoModelForCausalLM.from_pretrained(model_path,
                                                  load_in_4bit=True,
                                                  trust_remote_code=True,
-                                                 modules_to_not_convert=['c_fc', 'out_proj'])
+                                                 modules_to_not_convert=['c_fc', 'out_proj'],
+                                                 torch_dtype=torch.float32)
     model = model.to('xpu')
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
 
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
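
With the hooks removed, the example only needs the float32 load and the move to XPU shown above. A hedged sketch of how the rest of such an example typically continues, using Qwen-VL-Chat's standard chat interface (the image path and prompt are illustrative, and tokenizer is assumed to come from AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)):

    query = tokenizer.from_list_format([
        {'image': 'demo.jpg'},                 # illustrative local image path
        {'text': 'What is in the picture?'},
    ])
    response, history = model.chat(tokenizer, query=query, history=None)
    print(response)
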
@@ -47,16 +47,6 @@ if __name__ == '__main__':
                                                low_bit='sym_int4',
                                                modules_to_not_convert=['c_fc', 'out_proj'])
     model = model.to('xpu')
-    # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
-    # currently put interpolation execution into cpu
-    def to_cpu(module, input, output):
-        return output.to("cpu")
-
-    def to_xpu(module, input):
-        return (input[0].to("xpu"),)
-
-    model.transformer.visual.ln_pre.register_forward_hook(to_cpu)
-    model.transformer.visual.transformer.register_forward_pre_hook(to_xpu)
 
     # Specify hyperparameters for generation (No need to do this if you are using transformers>=4.32.0)
     model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)

@@ -689,6 +689,23 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
+
+    if model.config.model_type == "qwen" and hasattr(model.config, "visual"):
+        # for Qwen-VL-Chat
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        visual_module_name = model.transformer.visual.__class__.__module__
+        visual_module = importlib.import_module(visual_module_name)
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_vision_transformer_forward
+        from bigdl.llm.transformers.models.qwen_vl import qwen_vl_resampler_forward
+        convert_forward(model,
+                        visual_module.VisionTransformer,
+                        qwen_vl_vision_transformer_forward
+                        )
+        convert_forward(model,
+                        visual_module.Resampler,
+                        qwen_vl_resampler_forward
+                        )
     return model
 
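
convert_forward, used above, attaches the patched forward functions to the existing instances of Qwen-VL's VisionTransformer and Resampler classes. Roughly, and only as a sketch of the idea rather than the library's verbatim code, it amounts to rebinding forward across the module tree:

    import types
    import torch.nn as nn

    def convert_forward_sketch(module: nn.Module, target_cls, new_forward):
        # Recursively walk the children; wherever a submodule is an instance of
        # target_cls, bind new_forward to that instance as its forward method.
        for _, child in module.named_children():
            if isinstance(child, target_cls):
                child.forward = types.MethodType(new_forward, child)
            convert_forward_sketch(child, target_cls, new_forward)
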
@@ -47,6 +47,27 @@ def apply_rotary_pos_emb(t, freqs):
     return torch.cat((t_, t_pass_), dim=-1).type_as(t)
 
+
+def get_abs_pos(abs_pos, tgt_size):
+    # abs_pos: L, C
+    # tgt_size: M
+    # return: M, C
+    src_size = int(math.sqrt(abs_pos.size(0)))
+    tgt_size = int(math.sqrt(tgt_size))
+    dtype = abs_pos.dtype
+
+    if src_size != tgt_size:
+        # Due to issue https://github.com/intel/intel-extension-for-pytorch/issues/454,
+        # currently put interpolation execution into cpu
+        return F.interpolate(
+            abs_pos.to("cpu").float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
+            size=(tgt_size, tgt_size),
+            mode="bicubic",
+            align_corners=False,
+        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype).to(abs_pos.device)
+    else:
+        return abs_pos
+
 
 def qwen_attention_forward_vl(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],
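
get_abs_pos keeps Qwen-VL's original (L, C) -> (M, C) contract; only the device the interpolation runs on changes. A small shape check, as an illustration with made-up grid sizes (it assumes get_abs_pos and the module's torch import are in scope):

    import torch

    pos = torch.randn(32 * 32, 4096)          # L = 1024 source positions, C = 4096
    resized = get_abs_pos(pos, 16 * 16)       # ask for M = 256 target positions
    assert resized.shape == (16 * 16, 4096)   # interpolated on CPU, returned on pos.device
    assert get_abs_pos(pos, 32 * 32) is pos   # no-op when source and target grids match
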
@@ -151,3 +172,45 @@ def qwen_attention_forward_vl(
         outputs += (attn_weight,)
 
     return outputs
+
+
+def qwen_vl_resampler_forward(self, x, attn_mask=None):
+
+    pos_embed = get_abs_pos(self.pos_embed, x.size(1))
+
+    x = self.kv_proj(x)
+    x = self.ln_kv(x).permute(1, 0, 2)
+
+    N = x.shape[1]
+    q = self.ln_q(self.query)
+    out = self.attn(
+        self._repeat(q, N) + self.pos_embed.unsqueeze(1),
+        x + pos_embed.unsqueeze(1),
+        x,
+        attn_mask=attn_mask)[0]
+    return out.permute(1, 0, 2)
+
+
+def qwen_vl_vision_transformer_forward(self, x: torch.Tensor):
+    x = x.to(
+        dtype=self.transformer.get_cast_dtype(),
+        device=self.transformer.get_cast_device(),
+    )
+    # to patches
+    x = self.conv1(x)  # shape = [*, width, grid, grid]
+    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+    x = x + get_abs_pos(self.positional_embedding, x.size(1))
+
+    x = self.ln_pre(x)
+
+    x = x.permute(1, 0, 2)  # NLD -> LND
+    x = self.transformer(x)
+    x = x.permute(1, 0, 2)  # LND -> NLD
+
+    x = self.attn_pool(x)
+    x = self.ln_post(x)
+    x = x @ self.proj
+
+    return x