Fix low memory generation example issue in transformers 4.36 (#10702)
* update cache in low memory generate
* update
parent 6e7da0d92c
commit 1e817926ba
1 changed file with 28 additions and 11 deletions
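Background for the patch below: starting with transformers 4.36 the Llama decoder layers call the new Cache API (e.g. DynamicCache) on whatever is passed as past_key_value, so the old per-layer tuples no longer work when the layers are driven manually as in this example. A minimal, self-contained sketch of the conversion idiom the patch adopts; the tensor shapes are illustrative and not taken from the example:

import torch
from transformers.cache_utils import Cache, DynamicCache

# Legacy cache: one (key, value) pair per layer, each of shape
# (batch, num_heads, seq_len, head_dim). Shapes here are made up for illustration.
legacy = ((torch.randn(1, 32, 4, 128), torch.randn(1, 32, 4, 128)),)

past_key_values = legacy
if not isinstance(past_key_values, Cache):
    # Wrap the tuple-of-tuples in the 4.36 cache object.
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

# Number of already-cached tokens the next forward pass can reuse.
print(past_key_values.get_usable_length(1))  # 4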
@@ -28,6 +28,7 @@ import torch.nn as nn
 from accelerate import init_empty_weights
 from safetensors.torch import load_file, save_file
 from tqdm.auto import tqdm
+import transformers
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -247,6 +248,12 @@ class LowMemoryLlama(GenerationMixin):
         *args,
         **kwargs,
     ):
+        from packaging import version
+        trans_version = transformers.__version__
+        if version.parse(trans_version) >= version.parse("4.36.0"):
+            transformers_4_36 = True
+        else:
+            transformers_4_36 = False
         # Reinit model and clean memory
         del self.model
@@ -257,18 +264,23 @@ class LowMemoryLlama(GenerationMixin):
         # Send batch to device
         inputs = input_ids.to(self.device)

+        current_shape = inputs.shape[1]
         # Set up kv cache
-        kv_cache = {}
-        if past_key_values is None:
-            past_key_values = {}
+        if transformers_4_36:
+            from transformers.cache_utils import Cache, DynamicCache
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            pre_shape = past_key_values.get_usable_length(current_shape)
+        else:
+            if past_key_values is not None:
+                pre_shape = past_key_values[0][0].size(2)
+            else:
+                pre_shape = 0
+                past_key_values = [None] * len(self.model.model.layers)

         with torch.inference_mode():
             # Generate attention mask and position ids
-            current_shape = inputs.shape[1]
-            if past_key_values.get(self.layer_names[1], None):
-                pre_shape = past_key_values[self.layer_names[1]][0].size(2)
-            else:
-                pre_shape = 0
             pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
             attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]
@@ -282,9 +294,14 @@ class LowMemoryLlama(GenerationMixin):
             if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
                 inputs = layer(inputs)
             else:
-                inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
+                decoder_layer_index = int(layer_name.split('.')[-1])
+                past_key_value = past_key_values if transformers_4_36 else past_key_values[decoder_layer_index]
+                inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_value,
                                              position_ids=pos, attention_mask=attn)
-                kv_cache[layer_name] = new_kv_cache
+                if transformers_4_36:
+                    past_key_values = new_kv_cache
+                else:
+                    past_key_values[decoder_layer_index] = new_kv_cache

             # Delete weight before moving to('meta')
             for module in layer.modules():
@@ -296,7 +313,7 @@ class LowMemoryLlama(GenerationMixin):

         result = CausalLMOutputWithPast(
             logits=inputs.detach(),
-            past_key_values=kv_cache,
+            past_key_values=past_key_values,
         )
         return result
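Note on the per-layer loop above: with transformers >= 4.36 each decoder layer appends its new key/value states into the shared DynamicCache via update() and returns that same cache object, which is why the patched loop can simply reassign past_key_values to whatever the layer hands back instead of keeping a per-layer dict. A small sketch of that behaviour (shapes and layer count are illustrative, not from the example):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 32, 4, 128)  # (batch, num_heads, seq_len, head_dim), illustrative
v = torch.randn(1, 32, 4, 128)

# What a layer does internally when use_cache=True: store the states under its
# layer_idx and get back the concatenated key/value tensors to attend over.
k_all, v_all = cache.update(k, v, layer_idx=0)

print(cache.get_seq_length())               # 4 cached tokens for layer 0
print(cache.to_legacy_cache()[0][0].shape)  # torch.Size([1, 32, 4, 128])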