Qwen fused qkv (#10368)
* fused qkv + rope for qwen * quantized kv cache * fix * update qwen * fixed quantized qkv * fix * meet code review * update split * convert.py * extend when no enough kv * fix
This commit is contained in:
parent
741c2bf1df
commit
28c4a8cf5c
2 changed files with 346 additions and 84 deletions
|
|
@ -587,6 +587,41 @@ def _optimize_pre(model):
|
||||||
):
|
):
|
||||||
from bigdl.llm.transformers.models.bert import merge_qkv
|
from bigdl.llm.transformers.models.bert import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
|
if model.config.model_type == "qwen":
|
||||||
|
position_ids = torch.arange(0, model.config.max_position_embeddings)
|
||||||
|
rope_base = model.config.rotary_emb_base
|
||||||
|
from accelerate.big_modeling import init_empty_weights
|
||||||
|
|
||||||
|
def split_qkv_proj_func(module):
|
||||||
|
if "QWenAttention" in module.__class__.__name__:
|
||||||
|
c_attn_weight = module.c_attn.weight.data
|
||||||
|
c_attn_bias = module.c_attn.bias.data
|
||||||
|
projection_size = module.projection_size
|
||||||
|
hid_size = module.hidden_size
|
||||||
|
with init_empty_weights():
|
||||||
|
q_proj = torch.nn.Linear(hid_size, projection_size)
|
||||||
|
k_proj = torch.nn.Linear(hid_size, projection_size)
|
||||||
|
v_proj = torch.nn.Linear(hid_size, projection_size)
|
||||||
|
if not model.config.to_dict().get("bigdl_transformers_low_bit", False):
|
||||||
|
q_proj.weight = torch.nn.Parameter(
|
||||||
|
c_attn_weight[:projection_size, :], requires_grad=False)
|
||||||
|
q_proj.bias = torch.nn.Parameter(
|
||||||
|
c_attn_bias[:projection_size], requires_grad=False)
|
||||||
|
k_proj.weight = torch.nn.Parameter(
|
||||||
|
c_attn_weight[projection_size: 2 * projection_size, :], requires_grad=False)
|
||||||
|
k_proj.bias = torch.nn.Parameter(
|
||||||
|
c_attn_bias[projection_size: 2 * projection_size], requires_grad=False)
|
||||||
|
v_proj.weight = torch.nn.Parameter(
|
||||||
|
c_attn_weight[2 * projection_size:, :], requires_grad=False)
|
||||||
|
v_proj.bias = torch.nn.Parameter(
|
||||||
|
c_attn_bias[2 * projection_size:], requires_grad=False)
|
||||||
|
module.q_proj = q_proj
|
||||||
|
module.k_proj = k_proj
|
||||||
|
module.v_proj = v_proj
|
||||||
|
module.position_ids = position_ids
|
||||||
|
module.rope_base = rope_base
|
||||||
|
del module.c_attn
|
||||||
|
model.apply(split_qkv_proj_func)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,25 @@ def apply_rotary_pos_emb(t, freqs):
|
||||||
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
|
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
|
||||||
|
|
||||||
|
|
||||||
|
def should_use_fuse_rope(self, query_states):
|
||||||
|
use_fuse_rope = query_states.device.type == "xpu"
|
||||||
|
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
|
||||||
|
return use_fuse_rope
|
||||||
|
|
||||||
|
|
||||||
|
def is_enough_kv_cache_room(layer_past, kv_seq_len=1):
|
||||||
|
# to determinate if is enough kv cache room in transformers between 4.31 and 4.35
|
||||||
|
# seq_len for current seq len
|
||||||
|
# For llama like kv cache, i.e., [bs, n_head, seq_len, head_dim]
|
||||||
|
if layer_past is None:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||||
|
cache_k = cache_k.transpose(1, 2)
|
||||||
|
cache_v = cache_v.transpose(1, 2)
|
||||||
|
return cache_k.stride(1) < (kv_seq_len + 1) * cache_k.size(3)
|
||||||
|
|
||||||
|
|
||||||
def qwen_attention_forward(
|
def qwen_attention_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
|
|
@ -77,20 +96,87 @@ def qwen_attention_forward(
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||||
|
forward_function = qwen_attention_forward_quantized
|
||||||
|
else:
|
||||||
|
forward_function = qwen_attention_forward_original
|
||||||
|
return forward_function(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
rotary_pos_emb_list,
|
||||||
|
layer_past,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def qwen_attention_forward_original(
|
||||||
|
self,
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
|
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
|
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
|
||||||
"flash attn and kv_cache quantization are not supported")
|
"flash attn and kv_cache quantization are not supported")
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
mixed_x_layer = self.c_attn(hidden_states)
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
||||||
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
|
||||||
|
if decoding_fast_path:
|
||||||
|
hidden_states = hidden_states.view(1, -1)
|
||||||
|
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||||
|
cache_k = cache_k.transpose(1, 2)
|
||||||
|
cache_v = cache_v.transpose(1, 2)
|
||||||
|
|
||||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
kv_seq_len = cache_k.shape[-2]
|
||||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
self.position_ids = self.position_ids.to(device)
|
||||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
position_ids = self.position_ids[kv_seq_len]
|
||||||
# query, key, value's shape: [bs, seq_len, num_heads, head_dim]
|
base = self.rope_base
|
||||||
|
if is_enough_kv_cache_room(layer_past, kv_seq_len):
|
||||||
|
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
cache_k.size(2),
|
||||||
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
|
dtype=cache_k.dtype,
|
||||||
|
device=hidden_states.device)
|
||||||
|
new_cache_k[:] = cache_k
|
||||||
|
new_cache_v[:] = cache_v
|
||||||
|
cache_k = new_cache_k
|
||||||
|
cache_v = new_cache_v
|
||||||
|
|
||||||
|
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
|
||||||
|
self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
|
||||||
|
self.v_proj.bias.data, position_ids, cache_k, cache_v, self.q_proj.weight.qtype,
|
||||||
|
self.v_proj.weight.qtype, kv_seq_len, self.head_dim, base]
|
||||||
|
import linear_q4_0
|
||||||
|
query, key, value = linear_q4_0.forward_qkv_bias(*args)
|
||||||
|
kv_seq_len += 1
|
||||||
|
query_size, key_size = 1, 1
|
||||||
|
else:
|
||||||
|
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
# TODO: speed up
|
||||||
|
# mixed_x_layer = self.c_attn(hidden_states)
|
||||||
|
# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
||||||
|
|
||||||
|
# query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||||
|
# key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
|
# value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
if rotary_pos_emb_list is not None:
|
if rotary_pos_emb_list is not None:
|
||||||
use_fuse_rope = query.device.type == "xpu" and not (self.training and query.requires_grad)
|
|
||||||
cur_len = query.shape[1]
|
cur_len = query.shape[1]
|
||||||
if len(rotary_pos_emb_list) == 1:
|
if len(rotary_pos_emb_list) == 1:
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
|
|
@ -115,7 +201,8 @@ def qwen_attention_forward(
|
||||||
cos, sin = rotary_pos_emb
|
cos, sin = rotary_pos_emb
|
||||||
cos = cos.to(query.dtype)
|
cos = cos.to(query.dtype)
|
||||||
sin = sin.to(query.dtype)
|
sin = sin.to(query.dtype)
|
||||||
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
|
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key,
|
||||||
|
sin, cos, "qwen")
|
||||||
query_list += [query]
|
query_list += [query]
|
||||||
key_list += [key]
|
key_list += [key]
|
||||||
else:
|
else:
|
||||||
|
|
@ -126,7 +213,6 @@ def qwen_attention_forward(
|
||||||
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
|
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
|
||||||
query = torch.cat(query_list, dim=0)
|
query = torch.cat(query_list, dim=0)
|
||||||
key = torch.cat(key_list, dim=0)
|
key = torch.cat(key_list, dim=0)
|
||||||
|
|
||||||
query_size, key_size = query.size(1), key.size(1)
|
query_size, key_size = query.size(1), key.size(1)
|
||||||
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
|
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
|
||||||
|
|
||||||
|
|
@ -146,42 +232,12 @@ def qwen_attention_forward(
|
||||||
else:
|
else:
|
||||||
causal_mask = None
|
causal_mask = None
|
||||||
|
|
||||||
if use_quantize_kv_cache(self.c_attn, hidden_states):
|
|
||||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
|
||||||
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
|
|
||||||
|
|
||||||
if layer_past is None:
|
|
||||||
# For first token, use original attn
|
|
||||||
attn_output, attn_weight = self._attn(
|
|
||||||
query, key, value, causal_mask, attention_mask, head_mask
|
|
||||||
)
|
|
||||||
if use_cache:
|
|
||||||
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
|
||||||
k_cache, v_cache = init_fp8_kv_cache(
|
|
||||||
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
|
|
||||||
device=query.device,
|
|
||||||
)
|
|
||||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
|
||||||
else:
|
|
||||||
k_cache, v_cache = layer_past[0], layer_past[1]
|
|
||||||
k_cache = k_cache.transpose(1, 2)
|
|
||||||
v_cache = v_cache.transpose(1, 2)
|
|
||||||
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
|
||||||
|
|
||||||
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
|
||||||
|
|
||||||
attn_output, attn_weight = core_attn(
|
|
||||||
self, query, key, value, causal_mask, attention_mask, head_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
bsz = key.size(0)
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
|
if not decoding_fast_path:
|
||||||
cache_k, cache_v = layer_past[0], layer_past[1]
|
cache_k, cache_v = layer_past[0], layer_past[1]
|
||||||
cache_k = cache_k.transpose(1, 2)
|
cache_k = cache_k.transpose(1, 2)
|
||||||
cache_v = cache_v.transpose(1, 2)
|
cache_v = cache_v.transpose(1, 2)
|
||||||
if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
|
if cache_k.stride(1) < kv_seq_len * cache_k.size(3):
|
||||||
# allocate new
|
|
||||||
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
new_cache_k, new_cache_v = extend_kv_cache(bsz,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
|
|
@ -193,7 +249,6 @@ def qwen_attention_forward(
|
||||||
new_cache_v[:] = cache_v
|
new_cache_v[:] = cache_v
|
||||||
cache_k = new_cache_k
|
cache_k = new_cache_k
|
||||||
cache_v = new_cache_v
|
cache_v = new_cache_v
|
||||||
|
|
||||||
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
||||||
key.transpose(1, 2), value.transpose(1, 2))
|
key.transpose(1, 2), value.transpose(1, 2))
|
||||||
key = key_states
|
key = key_states
|
||||||
|
|
@ -212,6 +267,7 @@ def qwen_attention_forward(
|
||||||
key = new_key_states
|
key = new_key_states
|
||||||
value = new_value_states
|
value = new_value_states
|
||||||
|
|
||||||
|
if not decoding_fast_path:
|
||||||
query = query.transpose(1, 2)
|
query = query.transpose(1, 2)
|
||||||
|
|
||||||
attn_output, attn_weight = self._attn(
|
attn_output, attn_weight = self._attn(
|
||||||
|
|
@ -234,6 +290,177 @@ def qwen_attention_forward(
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def qwen_attention_forward_quantized(
|
||||||
|
self,
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
|
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
invalidInputError(not self.use_flash_attn and not self.use_cache_quantization,
|
||||||
|
"flash attn and kv_cache quantization are not supported")
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states)
|
||||||
|
# TODO: use when decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
|
||||||
|
decoding_fast_path = False
|
||||||
|
if decoding_fast_path:
|
||||||
|
hidden_states = hidden_states.view(1, -1)
|
||||||
|
tmp_cache_k, tmp_cache_v = init_kv_cache(
|
||||||
|
bsz,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
position_ids = self.position_ids[self.kv_seq_len].to(device)
|
||||||
|
base = self.rope_base
|
||||||
|
|
||||||
|
args = [hidden_states, self.q_proj.weight.data, self.k_proj.weight.data,
|
||||||
|
self.v_proj.weight.data, self.q_proj.bias.data, self.k_proj.bias.data,
|
||||||
|
self.v_proj.bias.data, position_ids, tmp_cache_k, tmp_cache_v,
|
||||||
|
self.q_proj.weight.qtype, self.v_proj.weight.qtype, 0, self.head_dim, base]
|
||||||
|
import linear_q4_0
|
||||||
|
query, key, value = linear_q4_0.forward_qkv_bias(*args)
|
||||||
|
self.kv_seq_len += 1
|
||||||
|
kv_seq_len = self.kv_seq_len
|
||||||
|
query_size, key_size = 1, 1
|
||||||
|
else:
|
||||||
|
query = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
key = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
value = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
|
||||||
|
# TODO: speed up
|
||||||
|
# mixed_x_layer = self.c_attn(hidden_states)
|
||||||
|
# query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
||||||
|
|
||||||
|
# query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||||
|
# key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||||
|
# value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||||
|
if rotary_pos_emb_list is not None:
|
||||||
|
cur_len = query.shape[1]
|
||||||
|
if len(rotary_pos_emb_list) == 1:
|
||||||
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
|
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
||||||
|
if use_fuse_rope:
|
||||||
|
cos, sin = rotary_pos_emb
|
||||||
|
cos = cos.to(query.dtype)
|
||||||
|
sin = sin.to(query.dtype)
|
||||||
|
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key, sin, cos, "qwen")
|
||||||
|
else:
|
||||||
|
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||||
|
q_pos_emb, k_pos_emb = rotary_pos_emb
|
||||||
|
# Slice the pos emb for current inference
|
||||||
|
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||||
|
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||||
|
else:
|
||||||
|
query_list = []
|
||||||
|
key_list = []
|
||||||
|
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
|
||||||
|
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
|
||||||
|
if use_fuse_rope:
|
||||||
|
cos, sin = rotary_pos_emb
|
||||||
|
cos = cos.to(query.dtype)
|
||||||
|
sin = sin.to(query.dtype)
|
||||||
|
query, key = apply_rotary_pos_emb_cache_freq_xpu(query, key,
|
||||||
|
sin, cos, "qwen")
|
||||||
|
query_list += [query]
|
||||||
|
key_list += [key]
|
||||||
|
else:
|
||||||
|
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||||
|
q_pos_emb, k_pos_emb = rotary_pos_emb
|
||||||
|
# Slice the pos emb for current inference
|
||||||
|
query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
|
||||||
|
key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
|
||||||
|
query = torch.cat(query_list, dim=0)
|
||||||
|
key = torch.cat(key_list, dim=0)
|
||||||
|
query_size, key_size = query.size(1), key.size(1)
|
||||||
|
kv_seq_len = key_size if layer_past is None else key_size + layer_past[0].size(1)
|
||||||
|
|
||||||
|
if kv_seq_len > self.seq_length and self.use_logn_attn and not self.training:
|
||||||
|
seq_start = kv_seq_len - query_size
|
||||||
|
seq_end = kv_seq_len
|
||||||
|
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||||
|
query = query * logn_tensor.expand_as(query)
|
||||||
|
|
||||||
|
if query_size > 1:
|
||||||
|
causal_mask = torch.tril(
|
||||||
|
torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query.device)
|
||||||
|
).view(1, 1, kv_seq_len, kv_seq_len)
|
||||||
|
causal_mask = causal_mask[
|
||||||
|
:, :, kv_seq_len - query_size:kv_seq_len, :kv_seq_len
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
causal_mask = None
|
||||||
|
|
||||||
|
if layer_past is None:
|
||||||
|
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||||
|
# query, key, value's shape: [bs, num_heads, seq_len, head_dim]
|
||||||
|
|
||||||
|
# save kv seq len for decoding_fast_path
|
||||||
|
self.kv_seq_len = key.shape[-2]
|
||||||
|
# For first token, use original attn
|
||||||
|
attn_output, attn_weight = self._attn(
|
||||||
|
query, key, value, causal_mask, attention_mask, head_mask
|
||||||
|
)
|
||||||
|
if use_cache:
|
||||||
|
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||||
|
k_cache, v_cache = init_fp8_kv_cache(
|
||||||
|
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
|
||||||
|
device=query.device,
|
||||||
|
)
|
||||||
|
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||||
|
else:
|
||||||
|
if decoding_fast_path:
|
||||||
|
k_cache, v_cache = layer_past[0], layer_past[1]
|
||||||
|
k_cache = k_cache.transpose(1, 2)
|
||||||
|
v_cache = v_cache.transpose(1, 2)
|
||||||
|
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
||||||
|
|
||||||
|
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||||
|
|
||||||
|
attn_output, attn_weight = core_attn(
|
||||||
|
self, query, key, value, causal_mask, attention_mask, head_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||||
|
k_cache, v_cache = layer_past[0], layer_past[1]
|
||||||
|
k_cache = k_cache.transpose(1, 2)
|
||||||
|
v_cache = v_cache.transpose(1, 2)
|
||||||
|
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]
|
||||||
|
|
||||||
|
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
|
||||||
|
|
||||||
|
attn_output, attn_weight = core_attn(
|
||||||
|
self, query, key, value, causal_mask, attention_mask, head_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
context_layer = self._merge_heads(
|
||||||
|
attn_output, self.num_heads, self.head_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = self.c_proj(context_layer)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
|
||||||
|
else:
|
||||||
|
outputs = (attn_output, None)
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attn_weight,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
|
def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
|
||||||
if query.size(2) != 1 or query.device.type != 'xpu':
|
if query.size(2) != 1 or query.device.type != 'xpu':
|
||||||
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
# We have no CPU fp8 matmul implementation for now, so just upscale to fp32
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue