From 1e817926ba530dd3b6b6585a898ac529dbd615b0 Mon Sep 17 00:00:00 2001
From: Jiao Wang
Date: Tue, 9 Apr 2024 09:56:52 -0700
Subject: [PATCH] Fix low memory generation example issue in transformers 4.36 (#10702)

* update cache in low memory generate

* update
---
 .../Model/llama2/low_memory_generate.py       | 39 +++++++++++++------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/python/llm/example/GPU/PyTorch-Models/Model/llama2/low_memory_generate.py b/python/llm/example/GPU/PyTorch-Models/Model/llama2/low_memory_generate.py
index 55c9e70b..e5d2d65d 100644
--- a/python/llm/example/GPU/PyTorch-Models/Model/llama2/low_memory_generate.py
+++ b/python/llm/example/GPU/PyTorch-Models/Model/llama2/low_memory_generate.py
@@ -28,6 +28,7 @@ import torch.nn as nn
 from accelerate import init_empty_weights
 from safetensors.torch import load_file, save_file
 from tqdm.auto import tqdm
+import transformers
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -247,6 +248,12 @@ class LowMemoryLlama(GenerationMixin):
         *args,
         **kwargs,
     ):
+        from packaging import version
+        trans_version = transformers.__version__
+        if version.parse(trans_version) >= version.parse("4.36.0"):
+            transformers_4_36 = True
+        else:
+            transformers_4_36 = False
         # Reinit model and clean memory
         del self.model
 
@@ -257,18 +264,23 @@ class LowMemoryLlama(GenerationMixin):
 
         # Send batch to device
         inputs = input_ids.to(self.device)
+        current_shape = inputs.shape[1]
 
         # Set up kv cache
-        kv_cache = {}
-        if past_key_values is None:
-            past_key_values = {}
+        if transformers_4_36:
+            from transformers.cache_utils import Cache, DynamicCache
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            pre_shape = past_key_values.get_usable_length(current_shape)
+        else:
+            if past_key_values is not None:
+                pre_shape = past_key_values[0][0].size(2)
+            else:
+                pre_shape = 0
+                past_key_values = [None] * len(self.model.model.layers)
 
         with torch.inference_mode():
             # Generate attention mask and position ids
-            current_shape = inputs.shape[1]
-            if past_key_values.get(self.layer_names[1], None):
-                pre_shape = past_key_values[self.layer_names[1]][0].size(2)
-            else:
-                pre_shape = 0
             pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
             attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]
@@ -282,9 +294,14 @@ class LowMemoryLlama(GenerationMixin):
                 if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
                     inputs = layer(inputs)
                 else:
-                    inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
+                    decoder_layer_index = int(layer_name.split('.')[-1])
+                    past_key_value = past_key_values if transformers_4_36 else past_key_values[decoder_layer_index]
+                    inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_value,
                                                  position_ids=pos, attention_mask=attn)
-                    kv_cache[layer_name] = new_kv_cache
+                    if transformers_4_36:
+                        past_key_values = new_kv_cache
+                    else:
+                        past_key_values[decoder_layer_index] = new_kv_cache
 
                 # Delete weight before moving to('meta')
                 for module in layer.modules():
@@ -296,7 +313,7 @@ class LowMemoryLlama(GenerationMixin):
 
         result = CausalLMOutputWithPast(
             logits=inputs.detach(),
-            past_key_values=kv_cache,
+            past_key_values=past_key_values,
         )
         return result
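
Note (not part of the patch above): the core of the fix is that transformers >= 4.36 threads a Cache object (e.g. DynamicCache) through the decoder layers, while earlier releases use one (key, value) tuple per layer. The sketch below mirrors that version-dependent set-up in isolation; the helper name normalize_past_key_values and its arguments are hypothetical and only illustrate the branch the patch adds around the per-layer forward pass.

# Sketch only, assuming the same transformers APIs the patch relies on
# (Cache, DynamicCache.from_legacy_cache, Cache.get_usable_length).
from packaging import version

import transformers


def normalize_past_key_values(past_key_values, current_len, num_layers):
    """Hypothetical helper: return (cache, pre_shape), where pre_shape is the
    number of tokens already held in the KV cache."""
    if version.parse(transformers.__version__) >= version.parse("4.36.0"):
        # New-style cache: decoder layers expect a Cache object.
        from transformers.cache_utils import Cache, DynamicCache
        if not isinstance(past_key_values, Cache):
            # Wrap legacy tuples (or None) into a DynamicCache.
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        pre_shape = past_key_values.get_usable_length(current_len)
    else:
        # Legacy cache: one (key, value) tuple per decoder layer.
        if past_key_values is not None:
            pre_shape = past_key_values[0][0].size(2)  # cached sequence length
        else:
            pre_shape = 0
            past_key_values = [None] * num_layers
    return past_key_values, pre_shape

On the >= 4.36 path the same DynamicCache object is passed to every decoder layer and returned by each call, which is why the patch simply reassigns past_key_values from the layer output; on the legacy path the returned tuple is stored back at that layer's index.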