disable default quantize_kv of GQA on MTL (#11679)

* disable default quantizekv of gqa in mtl

* fix stype

* fix stype

* fix stype

* fix stype

* fix stype

* fix stype
This commit is contained in:
hxsz1997 2024-07-30 04:38:46 +03:00 committed by GitHub
parent c02003925b
commit 9b36877897
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 25 additions and 18 deletions

View file

@ -118,7 +118,8 @@ def llama_model_forward_4_36(
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input):
@ -157,7 +158,8 @@ def llama_model_forward_4_38(
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input):
@ -197,7 +199,8 @@ def llama_model_forward_4_41(
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input):
@ -425,7 +428,7 @@ def llama_attention_forward_4_31(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = llama_attention_forward_4_31_quantized
else:
forward_function = llama_attention_forward_4_31_original
@ -1027,7 +1030,7 @@ def llama_attention_forward_4_41(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = llama_attention_forward_4_41_quantized
else:
forward_function = llama_attention_forward_4_41_original
@ -1566,7 +1569,7 @@ def llama_attention_forward_4_38(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = llama_attention_forward_4_38_quantized
else:
forward_function = llama_attention_forward_4_38_original

View file

@ -83,7 +83,7 @@ def minicpm_attention_forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = minicpm_attention_forward_quantized
else:
forward_function = minicpm_attention_forward_original
@ -603,7 +603,9 @@ def minicpm_model_forward(
from ipex_llm.transformers.kv import DynamicFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads //
self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return minicpm_model_forward_internal(
@ -1051,7 +1053,7 @@ def minicpm_attention_forward_4_39(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = minicpm_attention_forward_quantized
else:
forward_function = minicpm_attention_forward_original_4_39

View file

@ -205,7 +205,8 @@ def mistral_model_forward_4_36(
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input_ids):
@ -237,7 +238,7 @@ def mistral_attention_forward(
use_cache: bool=False,
padding_mask: Optional[torch.Tensor]=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = mistral_attention_forward_quantized
else:
forward_function = mistral_attention_forward_original
@ -654,7 +655,7 @@ def mistral_attention_forward_4_36(
use_cache: bool=False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = mistral_attention_forward_4_36_quantized
else:
forward_function = mistral_attention_forward_4_36_original
@ -1110,7 +1111,7 @@ def mistral_attention_forward_4_39(
use_cache: bool=False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
forward_function = mistral_attention_forward_4_36_quantized
else:
forward_function = mistral_attention_forward_4_39_original

View file

@ -114,7 +114,8 @@ def qwen2_model_forward(
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = (
self.config.hidden_size != 3584 # disable quantize kv in specific model
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads//self.config.num_key_value_heads)
)
if use_cache:

View file

@ -75,7 +75,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
return new_cache_k, new_cache_v
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
warnings.warn(
"`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
@ -87,13 +87,13 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
return os.environ["IPEX_LLM_LOW_MEM"] == "1"
else:
return x.device.type == 'xpu' and kv_cache_device_check(x) \
return x.device.type == 'xpu' and kv_cache_device_check(x, kv_group) \
and hasattr(linear, "qtype") and \
linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
def kv_cache_device_check(x: torch.Tensor) -> bool:
return get_xpu_device_type(x) == "mtl" or \
def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
1 < x.size(0) and x.size(0) <= 8)