Support vpm and resampler modules of minicpm-v on NPU (#12375)
parent 85c9279e6e
commit 7a97fbb779

5 changed files with 592 additions and 15 deletions

@@ -633,7 +633,7 @@ def transformers_int4_npu_win(repo_id,
         model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=optimize_model,
                                           trust_remote_code=True, use_cache=True, max_context_len=max_context_len, max_prompt_len=int(in_out_len[0]),
                                           quantization_group_size=npu_group_size, transpose_value_cache=transpose_value_cache,
-                                          attn_implementation="eager", modules_to_not_convert=["vpm", "resampler"]).eval()
+                                          attn_implementation="eager", torch_dtype=torch.float16).eval()
         model = model.llm
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     else:

@@ -56,7 +56,7 @@ if __name__ == "__main__":

     model = AutoModelForCausalLM.from_pretrained(
         model_path,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float16,
         trust_remote_code=True,
         attn_implementation="eager",
         load_in_low_bit="sym_int4",

@@ -66,7 +66,6 @@ if __name__ == "__main__":
         intra_pp=args.intra_pp,
         inter_pp=args.inter_pp,
         transpose_value_cache=not args.disable_transpose_value_cache,
-        modules_to_not_convert=['vpm', 'resampler']
     )
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
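
Note: taken together, the two hunks above switch this example to float16 and drop the explicit exclusion of the vision modules, since vpm and resampler are now converted for NPU as well. A minimal sketch of the resulting load call, based only on the arguments visible in this diff; the import path and model path are assumptions, and the extra NPU pipeline arguments from the second hunk (intra_pp, inter_pp, transpose_value_cache) are omitted:

    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed import for the NPU examples

    model_path = "path/to/MiniCPM-V-checkpoint"  # placeholder

    # 'vpm' and 'resampler' no longer need to be listed in modules_to_not_convert;
    # the whole model is loaded in float16 and optimized for NPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        attn_implementation="eager",
        load_in_low_bit="sym_int4",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)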

@@ -47,7 +47,7 @@ if __name__ == '__main__':
     image_path = args.image_url_or_path

     model = AutoModel.from_pretrained(model_path,
-                                      torch_dtype=torch.float32,
+                                      torch_dtype=torch.float16,
                                       trust_remote_code=True,
                                       attn_implementation="eager",
                                       load_in_low_bit="sym_int4",

@@ -57,8 +57,7 @@ if __name__ == '__main__':
                                       intra_pp=args.intra_pp,
                                       inter_pp=args.inter_pp,
                                       transpose_value_cache=not args.disable_transpose_value_cache,
-                                      modules_to_not_convert=['vpm', 'resampler']
-                                      )
+                                      )
     tokenizer = AutoTokenizer.from_pretrained(model_path,
                                               trust_remote_code=True)
     model.eval()

@@ -46,12 +46,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)

-    # MiniCPM-V 2.6 must put lm_head on CPU now
-    cpu_lm_head = (
-        (model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
-         model.config.vocab_size == 151666)
-        or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
-    )
+    cpu_lm_head = os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"

     # workaround for MiniCPM-2B
     if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
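
Note: after this hunk, MiniCPM-V 2.6 no longer forces lm_head onto the CPU; the CPU fallback is now opt-in via the IPEX_LLM_CPU_LM_HEAD environment variable only. A small illustrative snippet (the variable name comes from the diff, everything else is just an example):

    import os

    # Any value other than "0" keeps lm_head on the CPU instead of the padded
    # NPU lm_head path that this PR adds for MiniCPM-V 2.6.
    os.environ["IPEX_LLM_CPU_LM_HEAD"] = "1"
    # ...then load the model, so optimize_llm_pre() sees the variable.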

@@ -76,6 +71,48 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,

     if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
         # MiniCPM-V
+        # convert conv2d and layernorm
+        from ipex_llm.transformers.npu_models.minicpmv_mp import MinicpmVPatchEmbedding, \
+            replace_with_Layernorm
+        origin_conv = model.vpm.embeddings.patch_embedding
+        new_conv = MinicpmVPatchEmbedding(
+            weight=origin_conv.weight.to(torch.float16),
+            bias=origin_conv.bias.to(torch.float16),
+            strides=model.config.vision_config.patch_size,
+        )
+        model.vpm.embeddings.patch_embedding = new_conv
+        del new_conv
+        replace_with_Layernorm(model, qtype=None, device='NPU',
+                               modules_to_not_convert=[], group_size=0)
+
+        # replace forward function
+        from ipex_llm.transformers.npu_models.minicpmv_mp import pad_mlp_fc2, pad_mlp_forward, \
+            encoder_attn_forward, multi_head_attn_forward, resampler_forward
+        model.apply(pad_mlp_fc2)  # pad mlp.fc2 to avoid compile error
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        setattr(module.Resampler, "forward", resampler_forward)
+        module = importlib.import_module(modeling_module_name.replace("modeling_minicpmv",
+                                                                      "resampler"))
+        setattr(module.MultiheadAttention, "multi_head_attention_forward", multi_head_attn_forward)
+        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+            # MiniCPM-V 2.6
+            module = importlib.import_module(modeling_module_name.replace("modeling_minicpmv",
+                                                                          "modeling_navit_siglip"))
+            setattr(module.SiglipAttention, "forward", encoder_attn_forward)
+            setattr(module.SiglipMLP, "forward", pad_mlp_forward)
+
+            # workaround for lm_head on NPU
+            from ipex_llm.transformers.npu_models.minicpmv_mp import pad_lm_head, lm_head_forward
+            model.apply(pad_lm_head)  # pad lm_head to avoid compile error
+            setattr(model.llm.lm_head, "forward", lm_head_forward)
+        elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
+            # MiniCPM-V 2.5
+            from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionMLP, \
+                Idefics2VisionAttention
+            convert_forward(model, Idefics2VisionAttention, encoder_attn_forward)
+            convert_forward(model, Idefics2VisionMLP, pad_mlp_forward)
+
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
             # MiniCPM-V 2
             model.llm.config.model_type = "minicpm"

@@ -126,9 +163,9 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
             model.lm_head = new_lm_head

     if model.config.model_type == "qwen2":
-        # for Qwen2-7B-Insturct, divide lm_head into 14 parts
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
-                not cpu_lm_head:
+        # for Qwen2-7B-Insturct and MiniCPM-V 2.6, divide lm_head into 14 parts
+        if model.config.hidden_size == 3584 and (model.config.vocab_size == 152064 or
+                                                 model.config.vocab_size == 151666) and not cpu_lm_head:
             # Do not split lm_head and use sym_int8 instead when mixed_precison is True
             if quantization_group_size == 0:
                 # Do not split lm_head and use sym_int8 instead when mixed_precison is True

python/llm/src/ipex_llm/transformers/npu_models/minicpmv_mp.py  (new file, 542 additions)

@@ -0,0 +1,542 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/idefics2/modeling_idefics2.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/resampler.py
# which is licensed under Apache License 2.0:
#
# Copyright 2024 OpenBMB

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, Tuple
from ipex_llm.utils.common.log4Error import invalidInputError
from torch import Tensor
import warnings
from torch.nn.functional import *
from torch.nn.modules.activation import *
from intel_npu_acceleration_library.backend.factory import NNFactory
import numpy as np
from functools import partial
import uuid
from ipex_llm.transformers.npu_models.mp_models_base import run_model
from ipex_llm.transformers.npu_models.convert import module_optimization


class MinicpmVConv2d(NNFactory):
    def __init__(
        self,
        input_shape,
        weight_shape,
        bias,
        strides,
        padding,
        dilation,
        groups,
        device: str = "NPU",
    ):
        super().__init__(False, device)

        # define input
        input = self.parameter(input_shape, dtype=np.float16)
        weight = self.parameter(weight_shape, dtype=np.float16)
        if bias is not None:
            bias_node = self.parameter((1, weight_shape[0], 1, 1), dtype=np.float16)
        else:
            bias_node = None

        input = self.concat(input, input, axis=2)  # current workaround for compile error
        res = self.convolution(input_node=input,
                               weights_node=weight,
                               bias=bias_node,
                               strides=strides,
                               padding=padding,
                               dilation=dilation,
                               groups=groups)
        res = self.slice(res, begin=[0, 0, 0, 0],
                         end=[res.shape[0], res.shape[1], 1, res.shape[3]])
        # define outputs
        res = self.convert_to_fp16(res)

        print("start compiling")
        self.compile()


class MinicpmVPatchEmbedding(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
        strides=1,
        padding=0,
        dilation=1,
        groups=1,
    ):
        super().__init__()

        self.op_id = str(uuid.uuid4())
        self.parameters = [weight]
        if bias is not None:
            self.parameters.append(bias)
        self.backend_cls = partial(
            MinicpmVConv2d,
            weight_shape=weight.shape,
            bias=bias,
            strides=strides,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

    def forward(self, x):
        x = x.to(torch.float16)
        return run_model(x, self.parameters, self.backend_cls, self.op_id)
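
Note: MinicpmVPatchEmbedding is what optimize_llm_pre() (in the hunk above) swaps in for the vision tower's Conv2d patch embedding. A sketch of that call site, mirroring the diff; model is assumed to be an already-loaded MiniCPM-V model:

    import torch

    origin_conv = model.vpm.embeddings.patch_embedding  # the original torch.nn.Conv2d
    model.vpm.embeddings.patch_embedding = MinicpmVPatchEmbedding(
        weight=origin_conv.weight.to(torch.float16),
        bias=origin_conv.bias.to(torch.float16),
        strides=model.config.vision_config.patch_size,  # stride taken from the patch size, as in the diff
    )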


class LayerNorm(NNFactory):
    def __init__(
        self,
        input_shape,
        weight_shape,
        bias_shape,
        eps,
        device: str = "NPU",
    ):
        super().__init__(False, device)

        # define input
        input = self.parameter(input_shape, dtype=np.float16)
        weight = self.parameter(weight_shape, dtype=np.float16)
        bias = self.parameter(bias_shape, dtype=np.float16)

        input = self.convert_to_fp32(input)
        mean_res = self.reduce_mean(input, -1, keep_dims=True,)
        variance = self.reduce_mean(
            self.power(input - mean_res, self.constant(np.array([[2]], dtype=np.float32))),
            -1,
            keep_dims=True,
        )
        eps = self.constant(eps)
        input = self.eltwise_div(input - mean_res, self.sqrt(self.eltwise_add(variance, eps)))
        weight = self.convert_to_fp32(weight)
        input = self.eltwise_mul(weight, input)
        bias = self.convert_to_fp32(bias)
        input = self.eltwise_add(bias, input)

        # define outputs
        input = self.convert_to_fp16(input)

        print("start compiling")
        self.compile()


class MinicpmVLayerNorm(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
        eps=1e-6,
    ):
        super().__init__()
        self.op_id = str(uuid.uuid4())
        self.parameters = [weight, bias]
        self.backend_cls = partial(
            LayerNorm,
            weight_shape=weight.shape,
            bias_shape=bias.shape,
            eps=eps,
        )

    def forward(self, x):
        x = x.to(torch.float16)
        return run_model(x, self.parameters, self.backend_cls, self.op_id)


@module_optimization
def replace_with_Layernorm(layer, qtype=None, device='NPU',
                           modules_to_not_convert=[], group_size=0):
    if isinstance(layer, torch.nn.LayerNorm):
        return MinicpmVLayerNorm(
            weight=layer.weight.to(torch.float16),
            bias=layer.bias.to(torch.float16),
        )
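
Note: replace_with_Layernorm is wrapped in ipex-llm's module_optimization decorator, so (as it is used in optimize_llm_pre above) it is applied across the model's submodules and every torch.nn.LayerNorm it visits is swapped for the NPU-backed MinicpmVLayerNorm. The call, copied from the diff:

    replace_with_Layernorm(model, qtype=None, device='NPU',
                           modules_to_not_convert=[], group_size=0)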


def pad_mlp_fc2(module: torch.nn.Module):
    if hasattr(module, 'fc2') and module.fc2.in_features == 4304:
        new_linear = torch.nn.Linear(0, 0, bias=True)
        padded_weight = torch.cat((module.fc2.weight, module.fc2.weight[:, :(1152*4-4304)]), dim=1)
        new_weight = torch.nn.Parameter(padded_weight, requires_grad=False)
        new_linear.weight = new_weight
        new_linear.bias = module.fc2.bias
        new_linear.in_features = new_weight.size(1)
        new_linear.out_features = new_weight.size(0)
        module.fc2 = new_linear
        del new_linear


def pad_mlp_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states = self.fc1(hidden_states)
    hidden_states = self.activation_fn(hidden_states)
    hidden_states = F.pad(hidden_states,
                          (0, (1152*4-4304), 0, 0, 0, 0))
    hidden_states = self.fc2(hidden_states)
    return hidden_states
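
Note: the two helpers above keep the vision MLP numerically unchanged while giving fc2 an NPU-friendly width: the input dimension is padded from 4304 to 1152*4 = 4608 by appending a copy of the first 304 weight columns, and the activations are padded with 304 zeros, so the duplicated columns only ever multiply zeros. A small self-contained check of that identity (toy weights, double precision for an exact comparison):

    import torch
    import torch.nn.functional as F

    w = torch.randn(16, 4304, dtype=torch.float64)      # stand-in for fc2.weight
    x = torch.randn(2, 7, 4304, dtype=torch.float64)    # stand-in for the activations
    w_pad = torch.cat((w, w[:, :1152 * 4 - 4304]), dim=1)
    x_pad = F.pad(x, (0, 1152 * 4 - 4304, 0, 0, 0, 0))
    assert torch.allclose(x @ w.t(), x_pad @ w_pad.t())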


def pad_lm_head(module: torch.nn.Module):
    if hasattr(module, 'lm_head') and module.lm_head.in_features == 3584 \
            and module.lm_head.out_features == 151666:
        new_linear = torch.nn.Linear(0, 0, bias=False)
        padded_weight = F.pad(module.lm_head.weight,
                              (0, 0, 0, 152064-151666))  # 152064 is qwen2-7b vocab_size
        new_weight = torch.nn.Parameter(padded_weight, requires_grad=False)
        new_linear.weight = new_weight
        new_linear.in_features = new_weight.size(1)
        new_linear.out_features = new_weight.size(0)
        module.lm_head = new_linear
        del new_linear


def lm_head_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states = self(hidden_states)
    hidden_states = hidden_states[:, :, :151666]
    return hidden_states
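
Note: pad_lm_head and lm_head_forward play the same padding trick for the MiniCPM-V 2.6 language head: the 151666-row vocabulary projection is zero-padded to 152064 rows (the Qwen2-7B vocabulary size already handled by the NPU lm_head path), and the extra logits are sliced off again, so callers still see 151666 entries. A toy-sized illustration of why the slice recovers the original output:

    import torch
    import torch.nn.functional as F

    w = torch.randn(10, 4, dtype=torch.float64)   # stands in for the (151666, 3584) lm_head weight
    w_pad = F.pad(w, (0, 0, 0, 3))                # append zero rows, like 152064 - 151666 in pad_lm_head
    x = torch.randn(2, 4, dtype=torch.float64)
    assert torch.allclose((x @ w_pad.t())[:, :10], x @ w.t())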


def encoder_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel"""

    batch_size, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(batch_size, q_len,
                                     self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(batch_size, q_len,
                                 self.num_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, q_len,
                                     self.num_heads, self.head_dim).transpose(1, 2)

    k_v_seq_len = key_states.shape[-2]
    # ipex-llm change starts
    attn_weights = torch.matmul(query_states.float(),
                                key_states.float().transpose(2, 3)) * self.scale
    # ipex-llm change ends

    if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
        invalidInputError(False,
                          f"Attention weights should be of size ({batch_size, self.num_heads, }"
                          f"{q_len, k_v_seq_len}), but is {attn_weights.size()}")

    if attention_mask is not None:
        invalidInputError(attention_mask.size() == (batch_size, 1, q_len, k_v_seq_len),
                          f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}"
                          f", but is {attention_mask.size()}")
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    # ipex-llm change starts
    attn_output = torch.matmul(attn_weights.float(), value_states.float())
    # ipex-llm change ends

    invalidInputError(attn_output.size() == (batch_size, self.num_heads, q_len, self.head_dim),
                      f"`attn_output` should be of size ({batch_size, self.num_heads, }"
                      f"{q_len, self.head_dim}), but is {attn_output.size()}")

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights


def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    w_q, w_k, w_v = w.chunk(3)
    if b is None:
        b_q = b_k = b_v = None
    else:
        b_q, b_k, b_v = b.chunk(3)
    return linear(q.float(), w_q.float(), b_q.float()), \
        linear(k.float(), w_k.float(), b_k.float()), \
        linear(v.float(), w_v.float(), b_v.float())


def multi_head_attn_forward(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    # port from https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/resampler.py#L338
    # to solve conflict of fp16 and fp32 dtype
    is_batched = True if query.dim() == 3 else False

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads

    # compute in-projection
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)

    # prep attention mask
    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            invalidInputError(attn_mask.shape == correct_2d_size,
                              f"The shape of the 2D attn_mask is {attn_mask.shape},"
                              f"but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            invalidInputError(attn_mask.shape == correct_3d_size,
                              f"The shape of the 3D attn_mask is {attn_mask.shape},"
                              f" but should be {correct_3d_size}.")
        else:
            invalidInputError(False, f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    #
    # reshape q, k, v for multi head attention and make em batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    # (deep breath) calculate attention and out projection
    if need_weights:
        B, Nt, E = q.shape
        q_scaled = q / math.sqrt(E)

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(attn_mask.float(),
                                                q_scaled.float(), k.transpose(-2, -1))
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)

        attn_output = torch.bmm(attn_output_weights.float(), v.float())

        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)

        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None


def resampler_forward(self, x, tgt_sizes=None):
    # port from https://huggingface.co/openbmb/MiniCPM-V-2_6/blob/main/resampler.py#L130
    bs = x.shape[0]

    device = x.device
    dtype = x.dtype

    patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

    self._adjust_pos_cache(tgt_sizes, device=device)

    max_patch_len = torch.max(patch_len)
    key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)

    pos_embed = []
    for i in range(bs):
        tgt_h, tgt_w = tgt_sizes[i]
        pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype))
        key_padding_mask[i, patch_len[i]:] = True

    pos_embed = torch.nn.utils.rnn.pad_sequence(
        pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2)  # BLD => L * B * D

    x = self.kv_proj(x)  # B * L * D
    x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

    q = self.ln_q(self.query)  # Q * D

    out = self.attn(
        self._repeat(q, bs),  # Q * B * D
        x + pos_embed,  # L * B * D + L * B * D
        x,
        key_padding_mask=key_padding_mask)[0]
    # out: Q * B * D
    x = out.permute(1, 0, 2)  # B * Q * D

    x = self.ln_post(x)
    # ipex-llm change starts
    x = x.float() @ self.proj.float()
    x = x.to(torch.float16)
    # ipex-llm change ends
    return x
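
Note: the '# ipex-llm change' block above exists because, on this NPU path, the resampler activations are float16 while self.proj is presumably still float32, and torch.matmul does not accept mixed dtypes; upcasting both operands and casting the result back keeps the rest of the pipeline in float16. A tiny illustration of the failure mode being avoided (toy shapes, not the real projection):

    import torch

    x = torch.randn(1, 64, 128, dtype=torch.float16)
    proj = torch.randn(128, 128, dtype=torch.float32)
    # x @ proj                                    # would raise a dtype-mismatch RuntimeError
    out = (x.float() @ proj).to(torch.float16)    # what resampler_forward does instead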