optimize phi3 memory usage (#11867)
This commit is contained in:
parent
5b83493b1a
commit
d4ee0a89f3
2 changed files with 26 additions and 3 deletions
|
|
@ -121,6 +121,21 @@ class DynamicNormalCache(DynamicCache):
|
|||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
||||
@classmethod
|
||||
def from_reserved(cls, layers: int,
|
||||
bsz: int, n_head: int, length: int, head_dim: int,
|
||||
dtype: torch.dtype, device: torch.device):
|
||||
past_key_values = cls()
|
||||
for _i in range(layers):
|
||||
k_cache, v_cache = init_kv_cache(
|
||||
bsz, n_head, head_dim,
|
||||
0, length + cls.KV_ALLOC_BLOCK_LENGTH,
|
||||
dtype, device
|
||||
)
|
||||
past_key_values.key_cache.append(k_cache)
|
||||
past_key_values.value_cache.append(v_cache)
|
||||
return past_key_values
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
|
|
|||
|
|
@ -254,9 +254,9 @@ def phi3_model_forward_wrapper(origin_model_forward):
|
|||
):
|
||||
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
input = input_ids if input_ids is not None else inputs_embeds
|
||||
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
|
||||
use_compress_kv = should_use_compresskv(input, input.shape[1])
|
||||
inputs = input_ids if input_ids is not None else inputs_embeds
|
||||
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
|
||||
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
|
||||
if use_cache:
|
||||
if use_compress_kv and not isinstance(past_key_values,
|
||||
DynamicCompressCache):
|
||||
|
|
@ -272,6 +272,14 @@ def phi3_model_forward_wrapper(origin_model_forward):
|
|||
DynamicCompressCache
|
||||
)):
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||
if past_key_values.get_seq_length() == 0:
|
||||
n_layer = self.config.num_hidden_layers
|
||||
n_head = self.config.num_attention_heads
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
past_key_values = DynamicNormalCache.from_reserved(
|
||||
n_layer, inputs.size(0), n_head, inputs.size(1), head_dim,
|
||||
inputs.dtype, inputs.device
|
||||
)
|
||||
return origin_model_forward(
|
||||
self=self,
|
||||
input_ids=input_ids,
|
||||
|
|
|
|||
Loading…
Reference in a new issue