Add Qwen register attention implementation (#11110)

* qwen_register
Zhao Changmin 2024-05-23 17:17:45 +08:00 committed by GitHub
parent 0e53f20edb
commit c5e8b90c8d
2 changed files with 139 additions and 6 deletions

ipex_llm/transformers/convert.py

@@ -1243,13 +1243,20 @@ def _optimize_post(model, lightweight_bmm=False):
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from ipex_llm.transformers.models.qwen import qwen_attention_forward
+            from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
             from ipex_llm.transformers.models.qwen import qwen_mlp_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.qwen import qwen_model_forward
-            convert_forward(model,
-                            module.QWenAttention,
-                            qwen_attention_forward
-                            )
+            if model.config.max_position_embeddings == 8192:
+                convert_forward(model,
+                                module.QWenAttention,
+                                qwen_attention_forward_registered
+                                )
+            else:
+                convert_forward(model,
+                                module.QWenAttention,
+                                qwen_attention_forward
+                                )
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
@@ -1513,7 +1520,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.starcoder2 import model_forward
         convert_forward(model, module.Starcoder2Attention, attention_forward)
         convert_forward(model, module.Starcoder2Model, model_forward)
-    elif model.config.model_type == 'phi':
+    elif model.config.model_type in ["phi3", "phi3_v"]:
         # for phi-2
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1521,7 +1528,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.phi import model_forward
         convert_forward(model, module.PhiAttention, attention_forward)
         convert_forward(model, module.PhiModel, model_forward)
-    elif model.config.model_type in ["phi3", "phi3_v"]:
+    elif model.config.model_type == "phi3":
         # for phi-3
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)

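The new dispatch above keys on model.config.max_position_embeddings == 8192. A minimal usage sketch of the optimized path, assuming ipex-llm is installed and a Qwen checkpoint whose config reports max_position_embeddings == 8192 (the model name below is only an example); loading through ipex-llm runs _optimize_post(), which swaps in qwen_attention_forward_registered for such configs and keeps qwen_attention_forward otherwise:

from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Loading through ipex-llm triggers _optimize_post(); for configs with
# max_position_embeddings == 8192 the QWenAttention forward is replaced by
# qwen_attention_forward_registered, otherwise by qwen_attention_forward.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat",  # example checkpoint
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

inputs = tokenizer("What is AI?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
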
ipex_llm/transformers/models/qwen.py

@@ -180,6 +180,132 @@ def qwen_attention_forward(
     return attn_output, past_key_value
 
 
+def qwen_attention_forward_registered(
+    self,
+    hidden_states: Optional[Tuple[torch.FloatTensor]],
+    rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+    registered_causal_mask: Optional[torch.Tensor] = None,
+    layer_past: Optional[Tuple[torch.Tensor]] = None,
+    attention_mask: Optional[torch.FloatTensor] = None,
+    head_mask: Optional[torch.FloatTensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
+    #                   "flash attn and kv_cache quantization are not supported")
+    bsz, q_len, _ = hidden_states.size()
+    device = hidden_states.device
+
+    past_key_value = (None if layer_past is None
+                      else (layer_past[0].transpose(1, 2), layer_past[1].transpose(1, 2)))
+
+    qkv = self.c_attn(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads * 3, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_heads,
+                                                        self.num_heads], dim=1)
+
+    kv_seq_len = key_states.shape[2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[2]
+
+    # IPEX-LLM OPT: fuse rope
+    position_ids = rotary_pos_emb_list[-1]  # the last one is position_ids
+    inv_freq = rotary_pos_emb_list[-2]
+    rotary_pos_emb_list = rotary_pos_emb_list[:-2]
+    invalidInputError(len(rotary_pos_emb_list) == 1,
+                      "rotary_pos_emb_list's length cannot be larger than 1")
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
+    rotary_pos_emb = rotary_pos_emb_list[0]
+    if use_fuse_rope:
+        rot_dim = rotary_pos_emb[0].size(-1)
+        import linear_q4_0
+        linear_q4_0.rotary_half_inplaced(inv_freq, position_ids,
+                                         query_states[..., :rot_dim], key_states[..., :rot_dim])
+    else:
+        rotary_pos_emb = [i[:, -q_len:, :, :].transpose(1, 2) for i in rotary_pos_emb]
+        query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
+        key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
+
+    if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
+        seq_start = kv_seq_len - q_len
+        seq_end = kv_seq_len
+        logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].transpose(1, 2)
+        query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
+
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
+    key_states, value_states = update_past_key_value(
+        past_key_value, key_states, value_states,
+        kv_seq_len, use_quantize_kv, device
+    )
+    past_key_value = (key_states.transpose(1, 2),
+                      value_states.transpose(1, 2)) if use_cache else None
+
+    # IPEX-LLM OPT: sdp
+    attn_weights = None
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
+                                                     key_states.to(dtype=torch.float16),
+                                                     value_states.to(dtype=torch.float16),
+                                                     is_causal=True).to(hidden_states.dtype)
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
+        import linear_q4_0
+        if use_quantize_kv:
+            attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
+        else:
+            attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
+    else:
+        if q_len > 1:
+            causal_mask = registered_causal_mask[
+                :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len
+            ]
+            attention_mask = torch.zeros(causal_mask.shape, dtype=query_states.dtype,
+                                         device=query_states.device)
+            attention_mask.masked_fill_(causal_mask.logical_not(),
+                                        torch.finfo(attention_mask.dtype).min)
+            attention_mask = attention_mask.expand([bsz, -1, -1, -1])
+        else:
+            attention_mask = None
+
+        if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+            import linear_q4_0
+            if use_quantize_kv:
+                attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
+                                                  attention_mask)
+            else:
+                attn_output = linear_q4_0.sdp(query_states, key_states, value_states,
+                                              attention_mask)
+        else:
+            if use_quantize_kv:
+                key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                                query_states.dtype)
+            attn_weights = torch.matmul(query_states,
+                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
+            if self.softmax_in_fp32:
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                           dtype=torch.float32).to(
+                    value_states.dtype)
+            else:
+                attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.c_proj(attn_output)
+
+    if output_attentions:
+        return attn_output, past_key_value, attn_weights
+    else:
+        return attn_output, past_key_value
+
+
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.w1, "qtype", None)
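
For reference, the eager fallback in qwen_attention_forward_registered slices registered_causal_mask to the current window and converts it into an additive mask. A minimal sketch of the buffer shape this implies, assuming a lower-triangular boolean mask registered once for the full context window (the max_positions value and the standalone variables below are illustrative, not taken from this diff):

import torch

# Assumed buffer: [1, 1, max_positions, max_positions], lower-triangular bool.
max_positions = 8192  # illustrative; matches the max_position_embeddings check in convert.py
registered_causal_mask = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

# Mirror of the eager fallback: slice the current window and turn it into an
# additive mask (0 where attending is allowed, a large negative value elsewhere).
bsz, q_len, kv_seq_len = 1, 16, 16
causal_mask = registered_causal_mask[:, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len]
attention_mask = torch.zeros(causal_mask.shape, dtype=torch.float32)
attention_mask.masked_fill_(causal_mask.logical_not(), torch.finfo(attention_mask.dtype).min)
attention_mask = attention_mask.expand([bsz, -1, -1, -1])
print(attention_mask.shape)  # torch.Size([1, 1, 16, 16])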