disable default quantize_kv of GQA on MTL (#11679)
* disable default quantizekv of gqa in mtl * fix stype * fix stype * fix stype * fix stype * fix stype * fix stype
This commit is contained in:
parent
c02003925b
commit
9b36877897
5 changed files with 25 additions and 18 deletions
|
|
@ -118,7 +118,8 @@ def llama_model_forward_4_36(
|
|||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
input = input_ids if input_ids is not None else inputs_embeds
|
||||
if use_cache:
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
|
||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
elif should_use_compresskv(input):
|
||||
|
|
@ -157,7 +158,8 @@ def llama_model_forward_4_38(
|
|||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
input = input_ids if input_ids is not None else inputs_embeds
|
||||
if use_cache:
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
|
||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
elif should_use_compresskv(input):
|
||||
|
|
@ -197,7 +199,8 @@ def llama_model_forward_4_41(
|
|||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
input = input_ids if input_ids is not None else inputs_embeds
|
||||
if use_cache:
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
|
||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
elif should_use_compresskv(input):
|
||||
|
|
@ -425,7 +428,7 @@ def llama_attention_forward_4_31(
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = llama_attention_forward_4_31_quantized
|
||||
else:
|
||||
forward_function = llama_attention_forward_4_31_original
|
||||
|
|
@ -1027,7 +1030,7 @@ def llama_attention_forward_4_41(
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = llama_attention_forward_4_41_quantized
|
||||
else:
|
||||
forward_function = llama_attention_forward_4_41_original
|
||||
|
|
@ -1566,7 +1569,7 @@ def llama_attention_forward_4_38(
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = llama_attention_forward_4_38_quantized
|
||||
else:
|
||||
forward_function = llama_attention_forward_4_38_original
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ def minicpm_attention_forward(
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = minicpm_attention_forward_quantized
|
||||
else:
|
||||
forward_function = minicpm_attention_forward_original
|
||||
|
|
@ -603,7 +603,9 @@ def minicpm_model_forward(
|
|||
from ipex_llm.transformers.kv import DynamicFp8Cache
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
input = input_ids if input_ids is not None else inputs_embeds
|
||||
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
|
||||
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
|
||||
self.config.num_attention_heads //
|
||||
self.config.num_key_value_heads):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
return minicpm_model_forward_internal(
|
||||
|
|
@ -1051,7 +1053,7 @@ def minicpm_attention_forward_4_39(
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = minicpm_attention_forward_quantized
|
||||
else:
|
||||
forward_function = minicpm_attention_forward_original_4_39
|
||||
|
|
|
|||
|
|
@ -205,7 +205,8 @@ def mistral_model_forward_4_36(
|
|||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
if use_cache:
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
|
||||
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
|
||||
self.config.num_attention_heads//self.config.num_key_value_heads):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
elif should_use_compresskv(input_ids):
|
||||
|
|
@ -237,7 +238,7 @@ def mistral_attention_forward(
|
|||
use_cache: bool=False,
|
||||
padding_mask: Optional[torch.Tensor]=None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = mistral_attention_forward_quantized
|
||||
else:
|
||||
forward_function = mistral_attention_forward_original
|
||||
|
|
@ -654,7 +655,7 @@ def mistral_attention_forward_4_36(
|
|||
use_cache: bool=False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = mistral_attention_forward_4_36_quantized
|
||||
else:
|
||||
forward_function = mistral_attention_forward_4_36_original
|
||||
|
|
@ -1110,7 +1111,7 @@ def mistral_attention_forward_4_39(
|
|||
use_cache: bool=False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
|
||||
forward_function = mistral_attention_forward_4_36_quantized
|
||||
else:
|
||||
forward_function = mistral_attention_forward_4_39_original
|
||||
|
|
|
|||
|
|
@ -114,7 +114,8 @@ def qwen2_model_forward(
|
|||
inputs = input_ids if input_ids is not None else inputs_embeds
|
||||
use_quantize_kv = (
|
||||
self.config.hidden_size != 3584 # disable quantize kv in specific model
|
||||
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
|
||||
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
|
||||
self.config.num_attention_heads//self.config.num_key_value_heads)
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
|
|||
return new_cache_k, new_cache_v
|
||||
|
||||
|
||||
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
||||
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
|
||||
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
|
||||
warnings.warn(
|
||||
"`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
|
||||
|
|
@ -87,13 +87,13 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
|
|||
elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
|
||||
return os.environ["IPEX_LLM_LOW_MEM"] == "1"
|
||||
else:
|
||||
return x.device.type == 'xpu' and kv_cache_device_check(x) \
|
||||
return x.device.type == 'xpu' and kv_cache_device_check(x, kv_group) \
|
||||
and hasattr(linear, "qtype") and \
|
||||
linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
|
||||
|
||||
|
||||
def kv_cache_device_check(x: torch.Tensor) -> bool:
|
||||
return get_xpu_device_type(x) == "mtl" or \
|
||||
def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
|
||||
return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
|
||||
((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
|
||||
1 < x.size(0) and x.size(0) <= 8)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue