optimize phi3 memory usage (#11867)

This commit is contained in:
Yishuo Wang 2024-08-20 17:32:51 +08:00 committed by GitHub
parent 5b83493b1a
commit d4ee0a89f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 3 deletions

View file

@ -121,6 +121,21 @@ class DynamicNormalCache(DynamicCache):
return self.key_cache[layer_idx], self.value_cache[layer_idx]
@classmethod
def from_reserved(cls, layers: int,
bsz: int, n_head: int, length: int, head_dim: int,
dtype: torch.dtype, device: torch.device):
past_key_values = cls()
for _i in range(layers):
k_cache, v_cache = init_kv_cache(
bsz, n_head, head_dim,
0, length + cls.KV_ALLOC_BLOCK_LENGTH,
dtype, device
)
past_key_values.key_cache.append(k_cache)
past_key_values.value_cache.append(v_cache)
return past_key_values
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

View file

@ -254,9 +254,9 @@ def phi3_model_forward_wrapper(origin_model_forward):
):
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
use_compress_kv = should_use_compresskv(input, input.shape[1])
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
if use_cache:
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
@ -272,6 +272,14 @@ def phi3_model_forward_wrapper(origin_model_forward):
DynamicCompressCache
)):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
if past_key_values.get_seq_length() == 0:
n_layer = self.config.num_hidden_layers
n_head = self.config.num_attention_heads
head_dim = self.config.hidden_size // self.config.num_attention_heads
past_key_values = DynamicNormalCache.from_reserved(
n_layer, inputs.size(0), n_head, inputs.size(1), head_dim,
inputs.dtype, inputs.device
)
return origin_model_forward(
self=self,
input_ids=input_ids,