Fix low memory generation example issue in transformers 4.36 (#10702)

* update cache in low memory generate

* update
Author: Jiao Wang (committed by GitHub)
Date:   2024-04-09 09:56:52 -07:00
parent 6e7da0d92c
commit 1e817926ba

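Background for the patch: transformers 4.36 replaced the legacy tuple-of-tuples KV cache with Cache objects from transformers.cache_utils, which is what broke this low-memory example. A minimal sketch of the two layouts the patch has to bridge (the shapes are illustrative, not taken from the example):

    import torch
    from transformers.cache_utils import DynamicCache

    # transformers < 4.36: one (key, value) pair per decoder layer,
    # each shaped (batch, num_heads, seq_len, head_dim)
    legacy = tuple(
        (torch.zeros(1, 32, 10, 128), torch.zeros(1, 32, 10, 128))
        for _ in range(2)  # two layers, purely for illustration
    )
    print(legacy[0][0].size(2))        # cached length, read off a key tensor: 10

    # transformers >= 4.36: a single Cache object shared by all layers
    cache = DynamicCache.from_legacy_cache(legacy)
    print(cache.get_usable_length(1))  # the same length via the Cache API: 10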

@@ -28,6 +28,7 @@ import torch.nn as nn
 from accelerate import init_empty_weights
 from safetensors.torch import load_file, save_file
 from tqdm.auto import tqdm
+import transformers
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -247,6 +248,12 @@ class LowMemoryLlama(GenerationMixin):
         *args,
         **kwargs,
     ):
+        from packaging import version
+        trans_version = transformers.__version__
+        if version.parse(trans_version) >= version.parse("4.36.0"):
+            transformers_4_36 = True
+        else:
+            transformers_4_36 = False
         # Reinit model and clean memory
         del self.model
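The six added lines above just compute a version gate; an equivalent compact form, as a standalone sketch (packaging is already a dependency of transformers):

    from packaging import version
    import transformers

    # True from 4.36.0 on, where past_key_values became a Cache object
    transformers_4_36 = version.parse(transformers.__version__) >= version.parse("4.36.0")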
@@ -257,18 +264,23 @@ class LowMemoryLlama(GenerationMixin):
         # Send batch to device
         inputs = input_ids.to(self.device)
+        current_shape = inputs.shape[1]
         # Set up kv cache
-        kv_cache = {}
-        if past_key_values is None:
-            past_key_values = {}
+        if transformers_4_36:
+            from transformers.cache_utils import Cache, DynamicCache
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            pre_shape = past_key_values.get_usable_length(current_shape)
+        else:
+            if past_key_values is not None:
+                pre_shape = past_key_values[0][0].size(2)
+            else:
+                pre_shape = 0
+                past_key_values = [None] * len(self.model.model.layers)
         with torch.inference_mode():
             # Generate attention mask and position ids
-            current_shape = inputs.shape[1]
-            if past_key_values.get(self.layer_names[1], None):
-                pre_shape = past_key_values[self.layer_names[1]][0].size(2)
-            else:
-                pre_shape = 0
             pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
             attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]
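The new 4.36 branch above normalizes whatever the caller passed (a legacy tuple cache or None) into one DynamicCache, and get_usable_length() replaces the old manual size(2) read. Isolated into a runnable sketch, with current_shape standing in for the current input length:

    from transformers.cache_utils import Cache, DynamicCache

    current_shape = 5        # length of the incoming input_ids chunk
    past_key_values = None   # first generation step: no cache yet

    if not isinstance(past_key_values, Cache):
        # from_legacy_cache(None) yields an empty DynamicCache
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    pre_shape = past_key_values.get_usable_length(current_shape)
    print(pre_shape)         # 0 on the first step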
@@ -282,9 +294,14 @@ class LowMemoryLlama(GenerationMixin):
             if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
                 inputs = layer(inputs)
             else:
-                inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
+                decoder_layer_index = int(layer_name.split('.')[-1])
+                past_key_value = past_key_values if transformers_4_36 else past_key_values[decoder_layer_index]
+                inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_value,
                                              position_ids=pos, attention_mask=attn)
-                kv_cache[layer_name] = new_kv_cache
+                if transformers_4_36:
+                    past_key_values = new_kv_cache
+                else:
+                    past_key_values[decoder_layer_index] = new_kv_cache
             # Delete weight before moving to('meta')
             for module in layer.modules():
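Why the 4.36 path can simply re-bind past_key_values to new_kv_cache: a Cache-aware decoder layer calls update() on the shared cache and returns that same object, so every layer accumulates into one DynamicCache instead of producing a per-layer tuple. A small sketch of that update semantics (shapes again illustrative):

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    key = torch.zeros(1, 32, 5, 128)    # (batch, heads, seq, head_dim)
    value = torch.zeros(1, 32, 5, 128)

    # update() appends in place and returns the full key/value for that layer,
    # so the object one layer hands back is the cache the next call sees
    full_key, full_value = cache.update(key, value, layer_idx=0)
    print(cache.get_seq_length())       # 5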
@@ -296,7 +313,7 @@ class LowMemoryLlama(GenerationMixin):
         result = CausalLMOutputWithPast(
             logits=inputs.detach(),
-            past_key_values=kv_cache,
+            past_key_values=past_key_values,
         )
         return result
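With past_key_values returned instead of the old kv_cache dict, the output can round-trip through repeated forward calls on either side of 4.36. A hypothetical two-step driver; model, input_ids, and the exact forward signature are assumed here, not shown in the hunks above:

    out = model.forward(input_ids, past_key_values=None)  # prefill step
    next_token = out.logits[:, -1:].argmax(-1)            # greedy next-token pick
    out = model.forward(next_token, past_key_values=out.past_key_values)  # decode step reuses the cache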