[LLM] Add quantize_kv optimization for yuan2 model (#10243)
* add initial quantize_kv support for yuan2 model * fix yuan2 quantize_kv generation * apply fp16 conv layer optimizations * disable mlp for quantize_kv
This commit is contained in:
parent
a2ed4d714e
commit
13b0bc9075
2 changed files with 184 additions and 5 deletions
|
|
@ -1196,13 +1196,14 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.yuan import yuan_attention_forward
|
||||
from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
|
||||
# from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
|
||||
convert_forward(model,
|
||||
module.YuanAttention,
|
||||
yuan_attention_forward
|
||||
)
|
||||
convert_forward(model,
|
||||
module.YuanMLP,
|
||||
yuan_mlp_forward
|
||||
)
|
||||
# disable able mlp_forward for quantize_kv on mtl.
|
||||
# convert_forward(model,
|
||||
# module.YuanMLP,
|
||||
# yuan_mlp_forward
|
||||
# )
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -32,6 +32,8 @@ from bigdl.llm.utils.common import invalidInputError
|
|||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
|
||||
apply_rotary_pos_emb_cache_freq_xpu, mlp_fusion_check, fp16_fusion_check
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
|
||||
restore_fp8_kv_cache, use_quantize_kv_cache
|
||||
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
|
||||
from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
|
||||
|
||||
|
|
@ -144,6 +146,182 @@ def yuan_attention_forward(
|
|||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if use_quantize_kv_cache(self.merged_qk_proj, hidden_states):
|
||||
forward_function = yuan_attention_forward_quantized
|
||||
else:
|
||||
forward_function = yuan_attention_forward_origin
|
||||
return forward_function(
|
||||
self=self,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
|
||||
def yuan_attention_forward_quantized(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
device = hidden_states.device
|
||||
before_hidden_states = None
|
||||
is_first_step = False
|
||||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
|
||||
invalidInputError(use_cache, "use_cache=True is needed")
|
||||
invalidInputError(not self.use_shareqk, "use_shareqk is not supported for now")
|
||||
|
||||
if past_key_value is None:
|
||||
is_first_step = True
|
||||
if q_len >= 2:
|
||||
before_hidden_states = hidden_states[:, -2:, :].transpose(0, 1).half()
|
||||
else:
|
||||
before_hidden_states = torch.zeros(2, bsz, self.hidden_size,
|
||||
dtype=torch.half, device=hidden_states.device)
|
||||
before_hidden_states[-1:, :, :] = hidden_states[:, -1:, :].transpose(0, 1)
|
||||
else:
|
||||
before_hidden_states = past_key_value[2]
|
||||
this_hidden_states = torch.cat([
|
||||
before_hidden_states,
|
||||
hidden_states.transpose(0, 1).half(),
|
||||
], dim=0)
|
||||
before_hidden_states = this_hidden_states[-2:, :, ]
|
||||
|
||||
value_states = \
|
||||
self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if is_first_step:
|
||||
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
||||
None, hidden_states.dtype)
|
||||
else:
|
||||
hidden_states = yuan_localized_filtering_forward(self.lf_gate, hidden_states,
|
||||
this_hidden_states, hidden_states.dtype)
|
||||
qk_states = self.merged_qk_proj(hidden_states)
|
||||
(query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if use_fuse_rope:
|
||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
|
||||
key_states,
|
||||
sin, cos,
|
||||
"yuan",
|
||||
position_ids)
|
||||
else:
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states,
|
||||
key_states,
|
||||
cos, sin,
|
||||
position_ids,
|
||||
"yuan")
|
||||
|
||||
if past_key_value is None:
|
||||
# should use origin attn here
|
||||
attn_weights = torch.matmul(query_states,
|
||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
||||
"Attention weights should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
||||
f"but is {attn_weights.size()}")
|
||||
|
||||
if attention_mask is not None:
|
||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
||||
f"but is {attention_mask.size()}")
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights,
|
||||
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if use_cache:
|
||||
k_cache, v_cache = init_fp8_kv_cache(
|
||||
bsz, self.num_heads, kv_seq_len, self.head_dim, device=device
|
||||
)
|
||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||
key_states, value_states)
|
||||
past_key_value = (key_states, value_states, before_hidden_states)
|
||||
|
||||
else:
|
||||
k_cache, v_cache, _ = past_key_value
|
||||
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
|
||||
key_states, value_states)
|
||||
past_key_value = (key_states, value_states, before_hidden_states)
|
||||
|
||||
# torch.matmul
|
||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||
query_states.dtype)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
|
||||
|
||||
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||
|
||||
invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
|
||||
"Attention weights should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, "
|
||||
f"but is {attn_weights.size()}")
|
||||
|
||||
if attention_mask is not None:
|
||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
|
||||
f"but is {attention_mask.size()}")
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights,
|
||||
torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||
dtype=torch.float32).to(query_states.dtype)
|
||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
|
||||
value_states.transpose(-1, -2))
|
||||
|
||||
invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
|
||||
"`attn_output` should be of size "
|
||||
f"{(bsz, self.num_heads, q_len, self.head_dim)}, "
|
||||
f"but is {attn_output.size()}")
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def yuan_attention_forward_origin(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
|
|
|||
Loading…
Reference in a new issue