update quantize kv cache condition (#12681)
This commit is contained in:
		
							parent
							
								
									5d8081afbc
								
							
						
					
					
						commit
						7234c9b27b
					
				
					 20 changed files with 75 additions and 37 deletions
				
			
		| 
						 | 
				
			
			@ -73,7 +73,9 @@ def baichuan_model_7b_forward(
 | 
			
		|||
    if use_cache:
 | 
			
		||||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
 | 
			
		||||
                                                self.config.num_attention_heads,
 | 
			
		||||
                                                self.config.num_attention_heads)
 | 
			
		||||
        if use_compress_kv and not isinstance(past_key_values,
 | 
			
		||||
                                              DynamicCompressCache):
 | 
			
		||||
            if use_quantize_kv:
 | 
			
		||||
| 
						 | 
				
			
			@ -246,8 +248,6 @@ def baichuan_attention_forward_7b(
 | 
			
		|||
        key_states = key_states.to(hidden_states.dtype)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    if use_compresskv:
 | 
			
		||||
        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
 | 
			
		||||
| 
						 | 
				
			
			@ -258,6 +258,8 @@ def baichuan_attention_forward_7b(
 | 
			
		|||
            query_states, attention_mask, 1,
 | 
			
		||||
            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
 | 
			
		||||
    else:
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
 | 
			
		||||
                                                self.num_heads, self.num_heads)
 | 
			
		||||
        key_states, value_states = update_past_key_value(
 | 
			
		||||
            past_key_value, key_states, value_states,
 | 
			
		||||
            kv_seq_len, use_quantize_kv, device
 | 
			
		||||
| 
						 | 
				
			
			@ -308,7 +310,8 @@ def baichuan_attention_forward_13b(
 | 
			
		|||
        kv_seq_len += past_key_value[0].shape[2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
 | 
			
		||||
                                            self.num_heads, self.num_heads)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, device
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -63,8 +63,13 @@ def chatglm2_model_forward(
 | 
			
		|||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
 | 
			
		||||
        n_heads = self.config.num_attention_heads
 | 
			
		||||
        if self.config.multi_query_attention:
 | 
			
		||||
            n_kv_heads = self.config.multi_query_group_num
 | 
			
		||||
        else:
 | 
			
		||||
            n_kv_heads = n_heads
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
 | 
			
		||||
                                                input_ids)
 | 
			
		||||
                                                input_ids, n_heads, n_kv_heads)
 | 
			
		||||
        if use_compress_kv and not isinstance(past_key_values,
 | 
			
		||||
                                              DynamicCompressCache):
 | 
			
		||||
            if use_quantize_kv:
 | 
			
		||||
| 
						 | 
				
			
			@ -257,8 +262,6 @@ def chatglm2_attention_forward(
 | 
			
		|||
        key_states[..., :rot_dim] = k_rot[...]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    if use_compresskv:
 | 
			
		||||
        from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
| 
						 | 
				
			
			@ -272,6 +275,8 @@ def chatglm2_attention_forward(
 | 
			
		|||
            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
 | 
			
		||||
                                                n_head, n_kv_head)
 | 
			
		||||
        key_states, value_states = update_past_key_value(
 | 
			
		||||
            past_key_value, key_states, value_states,
 | 
			
		||||
            kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -55,8 +55,13 @@ def chatglm4_model_forward(
 | 
			
		|||
    if use_cache:
 | 
			
		||||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
 | 
			
		||||
                                                inputs)
 | 
			
		||||
        n_heads = self.config.num_attention_heads
 | 
			
		||||
        if self.config.multi_query_attention:
 | 
			
		||||
            n_kv_heads = self.config.multi_query_group_num
 | 
			
		||||
        else:
 | 
			
		||||
            n_kv_heads = n_heads
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs,
 | 
			
		||||
                                                n_heads, n_kv_heads)
 | 
			
		||||
        if use_compress_kv and not isinstance(past_key_values,
 | 
			
		||||
                                              DynamicCompressCache):
 | 
			
		||||
            if use_quantize_kv:
 | 
			
		||||
| 
						 | 
				
			
			@ -211,8 +216,6 @@ def chatglm4_attention_forward(
 | 
			
		|||
        key_states[..., :rot_dim] = k_rot[...]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    if use_compresskv:
 | 
			
		||||
        from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
| 
						 | 
				
			
			@ -226,6 +229,8 @@ def chatglm4_attention_forward(
 | 
			
		|||
            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
 | 
			
		||||
                                                n_head, n_kv_head)
 | 
			
		||||
        key_states, value_states = update_past_key_value(
 | 
			
		||||
            past_key_value, key_states, value_states,
 | 
			
		||||
            kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -230,7 +230,7 @@ def chatglm4v_attention_forward(
 | 
			
		|||
        key_states[..., :rot_dim] = k_rot[...]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -147,7 +147,7 @@ def glm_model_forward_wrapper(origin_forward):
 | 
			
		|||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
        use_cache = use_cache or inputs.device.type == 'xpu'
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
                                                self.config.num_attention_heads //
 | 
			
		||||
                                                self.config.num_attention_heads,
 | 
			
		||||
                                                self.config.num_key_value_heads)
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,7 +87,8 @@ def internlm_attention_forward(
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
 | 
			
		||||
                                            self.num_heads, self.num_heads)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
| 
						 | 
				
			
			@ -171,7 +172,8 @@ def internlm2_attention_forward(
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
 | 
			
		||||
                                            self.num_heads, self.num_key_value_heads)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
| 
						 | 
				
			
			@ -346,7 +348,8 @@ def internlm_xcomposser2_attention_forward(
 | 
			
		|||
            query_states, key_states, cos, sin, position_ids, "internlm")
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
 | 
			
		||||
                                            self.num_heads, self.num_key_value_heads)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, device
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -72,7 +72,7 @@ def llama_model_forward(
 | 
			
		|||
    use_cache = True if inputs.device.type == "xpu" else use_cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(
 | 
			
		||||
        self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
        self.config.num_attention_heads // self.config.num_key_value_heads
 | 
			
		||||
        self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    )
 | 
			
		||||
    use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
 | 
			
		||||
        isinstance(past_key_values, DynamicCompressCache)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -159,7 +159,7 @@ def minicpm_model_forward_wrapper(origin_forward):
 | 
			
		|||
        # IPEX-LLM OPT: kv cache and quantize kv cache
 | 
			
		||||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
 | 
			
		||||
                                                self.config.num_attention_heads //
 | 
			
		||||
                                                self.config.num_attention_heads,
 | 
			
		||||
                                                self.config.num_key_value_heads)
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
 | 
			
		||||
            isinstance(past_key_values, DynamicCompressCache)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,7 +66,9 @@ def minicpm3_model_forward_wrapper(origin_forward):
 | 
			
		|||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
        use_cache = True if inputs.device.type == "xpu" else use_cache
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
 | 
			
		||||
        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
                                                num_heads, num_kv_heads)
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -71,7 +71,7 @@ def mistral_model_forward(
 | 
			
		|||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    use_cache = use_cache or inputs.device.type == 'xpu'
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
                                            self.config.num_attention_heads //
 | 
			
		||||
                                            self.config.num_attention_heads,
 | 
			
		||||
                                            self.config.num_key_value_heads)
 | 
			
		||||
    use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
 | 
			
		||||
        isinstance(past_key_values, DynamicCompressCache)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -113,7 +113,7 @@ def mllama_text_model_forward(
 | 
			
		|||
    use_cache = True if inputs.device.type == "xpu" else use_cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(
 | 
			
		||||
        self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
        self.config.num_attention_heads // self.config.num_key_value_heads
 | 
			
		||||
        self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    )
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -249,7 +249,9 @@ def phi3_model_forward_wrapper(origin_model_forward):
 | 
			
		|||
        # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
 | 
			
		||||
        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
                                                num_heads, num_kv_heads)
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
 | 
			
		||||
            isinstance(past_key_values, DynamicCompressCache)
 | 
			
		||||
        if use_cache:
 | 
			
		||||
| 
						 | 
				
			
			@ -305,7 +307,9 @@ def phi3v_model_forward_wrapper(origin_model_forward):
 | 
			
		|||
    ):
 | 
			
		||||
        # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
 | 
			
		||||
        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids,
 | 
			
		||||
                                                num_heads, num_kv_heads)
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -107,7 +107,8 @@ def qwen_attention_forward(
 | 
			
		|||
        query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
 | 
			
		||||
                                            self.num_heads, self.num_heads)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, device
 | 
			
		||||
| 
						 | 
				
			
			@ -205,7 +206,8 @@ def qwen_attention_forward_registered(
 | 
			
		|||
        query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
 | 
			
		||||
                                            self.num_heads, self.num_heads)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, device
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -113,10 +113,10 @@ def qwen2_model_forward(
 | 
			
		|||
    # ipex-llm changes start
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv cache
 | 
			
		||||
    inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    use_quantize_kv = (
 | 
			
		||||
        self.config.hidden_size != 3584     # disable quantize kv in specific model
 | 
			
		||||
        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
 | 
			
		||||
                                  self.config.num_attention_heads//self.config.num_key_value_heads)
 | 
			
		||||
        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, num_heads, num_kv_heads)
 | 
			
		||||
    )
 | 
			
		||||
    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
 | 
			
		||||
        isinstance(past_key_values, DynamicCompressCache)
 | 
			
		||||
| 
						 | 
				
			
			@ -305,10 +305,11 @@ def qwen2_model_forward_4_42(
 | 
			
		|||
 | 
			
		||||
    # ipex-llm changes start
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv cache
 | 
			
		||||
    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    use_quantize_kv = (
 | 
			
		||||
        self.config.hidden_size != 3584     # disable quantize kv in specific model
 | 
			
		||||
        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
 | 
			
		||||
                                  self.config.num_attention_heads//self.config.num_key_value_heads)
 | 
			
		||||
                                  num_heads, num_kv_heads)
 | 
			
		||||
    )
 | 
			
		||||
    use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
 | 
			
		||||
        isinstance(past_key_values, DynamicCompressCache)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -73,8 +73,10 @@ def qwen2moe_model_forward(
 | 
			
		|||
    return_dict: Optional[bool] = None,
 | 
			
		||||
):
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    input = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input)
 | 
			
		||||
    inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, inputs,
 | 
			
		||||
                                            num_heads, num_kv_heads)
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -88,7 +88,9 @@ def qwen2_vl_model_forward(
 | 
			
		|||
    # IPEX-LLM OPT start: kv cache and quantize kv cache
 | 
			
		||||
    inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
    use_cache = True if inputs.device.type == "xpu" else use_cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
 | 
			
		||||
    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
                                            num_heads, num_kv_heads)
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -69,8 +69,10 @@ def stablelm_model_forward(
 | 
			
		|||
):
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv cache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
 | 
			
		||||
                       and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
 | 
			
		||||
                       and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
 | 
			
		||||
                                                 num_heads, num_kv_heads))
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -132,7 +132,9 @@ def model_forward(
 | 
			
		|||
    return_dict: Optional[bool] = None,
 | 
			
		||||
):
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
 | 
			
		||||
    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids,
 | 
			
		||||
                                            num_heads, num_kv_heads)
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -74,7 +74,8 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 | 
			
		|||
    return new_cache_k, new_cache_v
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
 | 
			
		||||
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
 | 
			
		||||
                          num_heads: int, num_kv_heads: int) -> bool:
 | 
			
		||||
    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
 | 
			
		||||
        warnings.warn(
 | 
			
		||||
            "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
 | 
			
		||||
| 
						 | 
				
			
			@ -90,8 +91,11 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
 | 
			
		|||
    else:
 | 
			
		||||
        device_name = get_xpu_device_name(x.device)
 | 
			
		||||
        return (
 | 
			
		||||
            device_name in ["mtl", "lnl", "arl"] and kv_group == 1
 | 
			
		||||
            or device_name in ["arc", "bmg"] and x.size(0) > 1
 | 
			
		||||
            num_kv_heads >= 4
 | 
			
		||||
            and (
 | 
			
		||||
                device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
 | 
			
		||||
                or device_name in ["arc", "bmg"] and x.size(0) > 1
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -158,7 +158,8 @@ def yuan_attention_forward(
 | 
			
		|||
                                                        "yuan")
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantzie kv cache
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states)
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states,
 | 
			
		||||
                                            self.num_heads, self.num_heads)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        None if past_key_value is None else (past_key_value[0], past_key_value[1]),
 | 
			
		||||
        key_states, value_states,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue