disable default quantize_kv of GQA on MTL (#11679)
* disable default quantize_kv of GQA on MTL
* fix style
* fix style
* fix style
* fix style
* fix style
* fix style
This commit is contained in:
parent c02003925b
commit 9b36877897
5 changed files with 25 additions and 18 deletions
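This patch threads the GQA group size into use_quantize_kv_cache so that the quantized KV cache is no longer selected by default for GQA models on MTL. A minimal sketch of the arithmetic behind the new argument, with illustrative head counts that are not taken from this patch:

# Illustrative only: kv_group is the GQA group size, i.e. how many query heads
# share one key/value head; plain MHA models have kv_group == 1.
num_attention_heads = 32   # example value, not from this patch
num_key_value_heads = 8    # GQA; equals num_attention_heads for plain MHA
kv_group = num_attention_heads // num_key_value_heads   # 4 here, 1 for MHA

With kv_group > 1 on an MTL device, the device check in the last hunk below now returns False, unless one of the environment-variable overrides handled earlier in use_quantize_kv_cache is set.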
@@ -118,7 +118,8 @@ def llama_model_forward_4_36(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -157,7 +158,8 @@ def llama_model_forward_4_38(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -197,7 +199,8 @@ def llama_model_forward_4_41(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input):
@@ -425,7 +428,7 @@ def llama_attention_forward_4_31(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_31_quantized
     else:
         forward_function = llama_attention_forward_4_31_original
@@ -1027,7 +1030,7 @@ def llama_attention_forward_4_41(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_41_quantized
     else:
         forward_function = llama_attention_forward_4_41_original
@@ -1566,7 +1569,7 @@ def llama_attention_forward_4_38(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = llama_attention_forward_4_38_quantized
     else:
         forward_function = llama_attention_forward_4_38_original
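Note on the two call forms in the hunks above and below: the model-forward hunks derive the group size from the config as num_attention_heads // num_key_value_heads, while the attention-forward hunks pass self.num_key_value_groups. In Hugging Face Llama-style attention modules that attribute is conventionally precomputed from the same two config fields (an assumption about upstream transformers, not something shown in this diff), so both forms supply the same value. An illustrative check with stand-in classes:

# Illustrative stand-ins only; not real ipex_llm or transformers classes.
class _Config:
    num_attention_heads = 32
    num_key_value_heads = 8

class _Attention:
    def __init__(self, config):
        # mirrors the upstream convention assumed above
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

cfg = _Config()
attn = _Attention(cfg)
assert attn.num_key_value_groups == cfg.num_attention_heads // cfg.num_key_value_heads  # both equal 4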
@@ -83,7 +83,7 @@ def minicpm_attention_forward(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = minicpm_attention_forward_quantized
     else:
         forward_function = minicpm_attention_forward_original
@@ -603,7 +603,9 @@ def minicpm_model_forward(
     from ipex_llm.transformers.kv import DynamicFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
+                                           self.config.num_attention_heads //
+                                           self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return minicpm_model_forward_internal(
@@ -1051,7 +1053,7 @@ def minicpm_attention_forward_4_39(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = minicpm_attention_forward_quantized
     else:
         forward_function = minicpm_attention_forward_original_4_39
@@ -205,7 +205,8 @@ def mistral_model_forward_4_36(
     from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
+                                 self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
         elif should_use_compresskv(input_ids):
@@ -237,7 +238,7 @@ def mistral_attention_forward(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_quantized
     else:
         forward_function = mistral_attention_forward_original
@@ -654,7 +655,7 @@ def mistral_attention_forward_4_36(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_4_36_quantized
     else:
         forward_function = mistral_attention_forward_4_36_original
@@ -1110,7 +1111,7 @@ def mistral_attention_forward_4_39(
     use_cache: bool=False,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states):
+    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
         forward_function = mistral_attention_forward_4_36_quantized
     else:
         forward_function = mistral_attention_forward_4_39_original
@@ -114,7 +114,8 @@ def qwen2_model_forward(
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_quantize_kv = (
         self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
+                                  self.config.num_attention_heads//self.config.num_key_value_heads)
     )
 
     if use_cache:
@@ -75,7 +75,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v
 
 
-def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
@@ -87,13 +87,13 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
         return os.environ["IPEX_LLM_LOW_MEM"] == "1"
     else:
-        return x.device.type == 'xpu' and kv_cache_device_check(x) \
+        return x.device.type == 'xpu' and kv_cache_device_check(x, kv_group) \
             and hasattr(linear, "qtype") and \
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 
 
-def kv_cache_device_check(x: torch.Tensor) -> bool:
-    return get_xpu_device_type(x) == "mtl" or \
+def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
+    return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
         ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
          1 < x.size(0) and x.size(0) <= 8)
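For reference, a self-contained sketch of the decision the patched kv_cache_device_check makes; the get_xpu_device_type stub below is a placeholder standing in for the real helper, which is not part of this diff:

import torch

def get_xpu_device_type(x: torch.Tensor) -> str:
    return "mtl"   # placeholder: pretend the tensor lives on an MTL iGPU

def kv_cache_device_check(x: torch.Tensor, kv_group: int) -> bool:
    # same logic as the patched helper in the hunk above
    return (get_xpu_device_type(x) == "mtl" and kv_group <= 1) or \
        ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and
         1 < x.size(0) and x.size(0) <= 8)

x = torch.zeros(1, 8)
print(kv_cache_device_check(x, kv_group=1))   # True: an MHA model on MTL still defaults to quantize_kv
print(kv_cache_device_check(x, kv_group=4))   # False: a GQA model on MTL no longer does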