Fix low memory generation example issue in transformers 4.36 (#10702)
* update cache in low memory generate * update
This commit is contained in:
parent
6e7da0d92c
commit
1e817926ba
1 changed files with 28 additions and 11 deletions
|
|
@ -28,6 +28,7 @@ import torch.nn as nn
|
|||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm.auto import tqdm
|
||||
import transformers
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
|
@ -247,6 +248,12 @@ class LowMemoryLlama(GenerationMixin):
|
|||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
from packaging import version
|
||||
trans_version = transformers.__version__
|
||||
if version.parse(trans_version) >= version.parse("4.36.0"):
|
||||
transformers_4_36 = True
|
||||
else:
|
||||
transformers_4_36 = False
|
||||
|
||||
# Reinit model and clean memory
|
||||
del self.model
|
||||
|
|
@ -257,18 +264,23 @@ class LowMemoryLlama(GenerationMixin):
|
|||
# Send batch to device
|
||||
inputs = input_ids.to(self.device)
|
||||
|
||||
current_shape = inputs.shape[1]
|
||||
# Set up kv cache
|
||||
kv_cache = {}
|
||||
if past_key_values is None:
|
||||
past_key_values = {}
|
||||
if transformers_4_36:
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
pre_shape = past_key_values.get_usable_length(current_shape)
|
||||
else:
|
||||
if past_key_values is not None:
|
||||
pre_shape = past_key_values[0][0].size(2)
|
||||
else:
|
||||
pre_shape = 0
|
||||
past_key_values = [None] * len(self.model.model.layers)
|
||||
|
||||
with torch.inference_mode():
|
||||
# Generate attention mask and position ids
|
||||
current_shape = inputs.shape[1]
|
||||
if past_key_values.get(self.layer_names[1], None):
|
||||
pre_shape = past_key_values[self.layer_names[1]][0].size(2)
|
||||
else:
|
||||
pre_shape = 0
|
||||
pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
|
||||
attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]
|
||||
|
||||
|
|
@ -282,9 +294,14 @@ class LowMemoryLlama(GenerationMixin):
|
|||
if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
|
||||
inputs = layer(inputs)
|
||||
else:
|
||||
inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
|
||||
decoder_layer_index = int(layer_name.split('.')[-1])
|
||||
past_key_value = past_key_values if transformers_4_36 else past_key_values[decoder_layer_index]
|
||||
inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_value,
|
||||
position_ids=pos, attention_mask=attn)
|
||||
kv_cache[layer_name] = new_kv_cache
|
||||
if transformers_4_36:
|
||||
past_key_values = new_kv_cache
|
||||
else:
|
||||
past_key_values[decoder_layer_index] = new_kv_cache
|
||||
|
||||
# Delete weight before moving to('meta')
|
||||
for module in layer.modules():
|
||||
|
|
@ -296,7 +313,7 @@ class LowMemoryLlama(GenerationMixin):
|
|||
|
||||
result = CausalLMOutputWithPast(
|
||||
logits=inputs.detach(),
|
||||
past_key_values=kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue