Qwen2 SDPA forward on CPU (#10395)
* Fix Qwen1.5 CPU forward * Update convert.py * Update qwen2.py
This commit is contained in:
parent
ca58a69b97
commit
d72c0fad0d
2 changed files with 155 additions and 5 deletions
|
|
@ -1083,20 +1083,25 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
|
||||
from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
|
||||
convert_forward(model,
|
||||
module.Qwen2Model,
|
||||
qwen2_model_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2Attention,
|
||||
qwen2_attention_forward
|
||||
)
|
||||
convert_forward(model,
|
||||
module.Qwen2RMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2MLP,
|
||||
llama_mlp_forward)
|
||||
if model.device.type == 'cpu':
|
||||
from bigdl.llm.transformers.models.qwen2 import qwen2_sdpa_attention_forward
|
||||
convert_forward(model,
|
||||
module.Qwen2SdpaAttention,
|
||||
qwen2_sdpa_attention_forward)
|
||||
else:
|
||||
from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
|
||||
convert_forward(model,
|
||||
module.Qwen2Attention,
|
||||
qwen2_attention_forward)
|
||||
elif model.config.model_type == "aquila":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
|
|
@ -379,3 +379,148 @@ def qwen2_attention_forward_origin(
|
|||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def qwen2_sdpa_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
|
||||
"Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
||||
qtype = getattr(self.q_proj, "qtype", None)
|
||||
qtype_check = qtype in [SYM_INT4, FP8E5]
|
||||
decoding_fast_path = (qtype_check and use_fuse_rope
|
||||
and enough_kv_room and bsz * q_len == 1)
|
||||
if decoding_fast_path:
|
||||
hidden_states = hidden_states.view(1, -1)
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
kv_seq_len = cache_k.shape[-2]
|
||||
import linear_q4_0
|
||||
args = [hidden_states, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight,
|
||||
self.q_proj.bias, self.k_proj.bias, self.v_proj.bias, position_ids, cache_k,
|
||||
cache_v, self.q_proj.weight.qtype, self.v_proj.weight.qtype, kv_seq_len,
|
||||
self.head_dim, self.rotary_emb.base]
|
||||
query_states, key_states, value_states = linear_q4_0.forward_qkv_bias(*args)
|
||||
kv_seq_len += 1
|
||||
if self.layer_idx == 0:
|
||||
past_key_value.seen_tokens = kv_seq_len
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
else:
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = \
|
||||
key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = \
|
||||
value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
invalidInputError(
|
||||
False,
|
||||
"The cache structure has changed since version v4.36. "
|
||||
f"If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, "
|
||||
"please make sure to initialize the attention class with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if use_fuse_rope:
|
||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
||||
sin, cos, "qwen2",
|
||||
position_ids)
|
||||
else:
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# update the number of seen tokens
|
||||
if self.layer_idx == 0:
|
||||
past_key_value.seen_tokens += key_states.shape[-2]
|
||||
|
||||
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||
past_key_value.key_cache.append(key_states)
|
||||
past_key_value.value_cache.append(value_states)
|
||||
else:
|
||||
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||
|
||||
if not enough_kv_room:
|
||||
# allocate new
|
||||
new_c_k, new_c_v = extend_kv_cache(bsz,
|
||||
self.num_key_value_heads, # Support GQA
|
||||
self.head_dim,
|
||||
cache_k.size(2),
|
||||
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||
dtype=cache_k.dtype,
|
||||
device=device)
|
||||
|
||||
new_c_k[:] = cache_k
|
||||
new_c_v[:] = cache_v
|
||||
cache_k = new_c_k
|
||||
cache_v = new_c_v
|
||||
|
||||
key_states, value_states = append_kv_cache(cache_k,
|
||||
cache_v,
|
||||
key_states,
|
||||
value_states)
|
||||
|
||||
# update past_key_value
|
||||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
||||
("Attention weights should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
|
||||
"but is {attn_weights.size()}"))
|
||||
|
||||
if attention_mask is not None:
|
||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||
(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
|
||||
f" but is {attention_mask.size()}"))
|
||||
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
||||
attn_output = sdpa(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=self.is_causal and attention_mask is None and q_len > 1)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
|
|
|||
Loading…
Reference in a new issue