Fix low memory generation example issue in transformers 4.36 (#10702)

* update cache in low memory generate

* update
commit 1e817926ba
parent 6e7da0d92c
Author: Jiao Wang
Date:   2024-04-09 09:56:52 -07:00 (committed by GitHub)

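Note on the change: transformers 4.36 replaced the per-layer (key, value) tuples that this example kept in a dict with Cache objects such as DynamicCache, so the old past_key_values.get(layer_name, None) bookkeeping no longer works; the diff below gates on the installed version. A minimal sketch of the two cache formats, assuming transformers >= 4.36 and torch are installed (the tensor shapes and layer count are illustrative only, not taken from the example):

import torch
from transformers.cache_utils import DynamicCache

# Legacy (pre-4.36) format: one (key, value) tuple per decoder layer,
# each tensor shaped [batch, num_heads, seq_len, head_dim].
legacy_cache = tuple(
    (torch.zeros(1, 32, 8, 128), torch.zeros(1, 32, 8, 128))
    for _ in range(2)  # illustrative 2-layer model
)

# transformers >= 4.36 wraps the same data in a Cache object.
cache = DynamicCache.from_legacy_cache(legacy_cache)

# The cached sequence length replaces the old past_key_values[0][0].size(2) lookup.
print(legacy_cache[0][0].size(2))  # 8
print(cache.get_usable_length(1))  # 8 for a DynamicCache with no max length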

@@ -28,6 +28,7 @@ import torch.nn as nn
 from accelerate import init_empty_weights
 from safetensors.torch import load_file, save_file
 from tqdm.auto import tqdm
+import transformers
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -247,6 +248,12 @@ class LowMemoryLlama(GenerationMixin):
         *args,
         **kwargs,
     ):
+        from packaging import version
+        trans_version = transformers.__version__
+        if version.parse(trans_version) >= version.parse("4.36.0"):
+            transformers_4_36 = True
+        else:
+            transformers_4_36 = False
 
         # Reinit model and clean memory
         del self.model
@@ -257,18 +264,23 @@ class LowMemoryLlama(GenerationMixin):
         # Send batch to device
         inputs = input_ids.to(self.device)
+        current_shape = inputs.shape[1]
 
         # Set up kv cache
-        kv_cache = {}
-        if past_key_values is None:
-            past_key_values = {}
+        if transformers_4_36:
+            from transformers.cache_utils import Cache, DynamicCache
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            pre_shape = past_key_values.get_usable_length(current_shape)
+        else:
+            if past_key_values is not None:
+                pre_shape = past_key_values[0][0].size(2)
+            else:
+                pre_shape = 0
+                past_key_values = [None] * len(self.model.model.layers)
 
         with torch.inference_mode():
             # Generate attention mask and position ids
-            current_shape = inputs.shape[1]
-            if past_key_values.get(self.layer_names[1], None):
-                pre_shape = past_key_values[self.layer_names[1]][0].size(2)
-            else:
-                pre_shape = 0
             pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
             attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]
 
@@ -282,9 +294,14 @@ class LowMemoryLlama(GenerationMixin):
                 if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
                     inputs = layer(inputs)
                 else:
-                    inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
+                    decoder_layer_index = int(layer_name.split('.')[-1])
+                    past_key_value = past_key_values if transformers_4_36 else past_key_values[decoder_layer_index]
+                    inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_value,
                                                  position_ids=pos, attention_mask=attn)
-                    kv_cache[layer_name] = new_kv_cache
+                    if transformers_4_36:
+                        past_key_values = new_kv_cache
+                    else:
+                        past_key_values[decoder_layer_index] = new_kv_cache
 
                 # Delete weight before moving to('meta')
                 for module in layer.modules():
@@ -296,7 +313,7 @@ class LowMemoryLlama(GenerationMixin):
 
         result = CausalLMOutputWithPast(
             logits=inputs.detach(),
-            past_key_values=kv_cache,
+            past_key_values=past_key_values,
         )
 
         return result
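Design note: with transformers >= 4.36 a single DynamicCache is passed to every decoder layer and each layer returns the updated cache, while older versions take and return one (key, value) tuple per layer; that is why the loop above either reassigns past_key_values wholesale or writes back into a per-layer slot. A minimal sketch of how the two paths initialize the cache on the first forward call, mirroring the version gate in the diff (num_layers is an illustrative placeholder for len(self.model.model.layers)):

import transformers
from packaging import version

transformers_4_36 = version.parse(transformers.__version__) >= version.parse("4.36.0")

if transformers_4_36:
    # One Cache object shared by all decoder layers; each layer updates and returns it.
    from transformers.cache_utils import DynamicCache
    past_key_values = DynamicCache()
else:
    # One (key, value) slot per decoder layer, filled in as each layer runs.
    num_layers = 32  # illustrative placeholder
    past_key_values = [None] * num_layers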