Fix low memory generation example issue in transformers 4.36 (#10702)
* update cache in low memory generate
* update
parent 6e7da0d92c
commit 1e817926ba

1 changed file with 28 additions and 11 deletions
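
Background for the change below: transformers 4.36 reworked the KV-cache API, so decoder layers now expect a Cache object (e.g. DynamicCache) rather than one (key, value) tuple per layer, which broke the example's dict-based cache. For context, a minimal sketch of the 4.36 cache calls the patch relies on (the tensor shapes are made up for illustration):

import torch
from transformers.cache_utils import Cache, DynamicCache

# Legacy format: one (key, value) tuple per decoder layer,
# each of shape (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.zeros(1, 32, 4, 128), torch.zeros(1, 32, 4, 128))
    for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap the tuples in a Cache object
assert isinstance(cache, Cache)
print(cache.get_usable_length(1))    # 4: tokens already cached
print(len(cache.to_legacy_cache()))  # 2: back to per-layer tuples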
				
			
@@ -28,6 +28,7 @@ import torch.nn as nn
 from accelerate import init_empty_weights
 from safetensors.torch import load_file, save_file
 from tqdm.auto import tqdm
+import transformers
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -247,6 +248,12 @@ class LowMemoryLlama(GenerationMixin):
                 *args,
                 **kwargs,
         ):
+        from packaging import version
+        trans_version = transformers.__version__
+        if version.parse(trans_version) >= version.parse("4.36.0"):
+            transformers_4_36 = True
+        else:
+            transformers_4_36 = False
 
         # Reinit model and clean memory
         del self.model
@@ -257,18 +264,23 @@ class LowMemoryLlama(GenerationMixin):
         # Send batch to device
         inputs = input_ids.to(self.device)
 
+        current_shape = inputs.shape[1]
         # Set up kv cache
-        kv_cache = {}
-        if past_key_values is None:
-            past_key_values = {}
+        if transformers_4_36:
+            from transformers.cache_utils import Cache, DynamicCache
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            pre_shape = past_key_values.get_usable_length(current_shape)
+        else:
+            if past_key_values is not None:
+                pre_shape = past_key_values[0][0].size(2)
+            else:
+                pre_shape = 0
+                past_key_values = [None] * len(self.model.model.layers)
 
         with torch.inference_mode():
             # Generate attention mask and position ids
-            current_shape = inputs.shape[1]
-            if past_key_values.get(self.layer_names[1], None):
-                pre_shape = past_key_values[self.layer_names[1]][0].size(2)
-            else:
-                pre_shape = 0
             pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
             attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]
 
@@ -282,9 +294,14 @@ class LowMemoryLlama(GenerationMixin):
                 if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
                     inputs = layer(inputs)
                 else:
-                    inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
+                    decoder_layer_index = int(layer_name.split('.')[-1])
+                    past_key_value = past_key_values if transformers_4_36 else past_key_values[decoder_layer_index]
+                    inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_value,
                                                  position_ids=pos, attention_mask=attn)
-                    kv_cache[layer_name] = new_kv_cache
+                    if transformers_4_36:
+                        past_key_values = new_kv_cache
+                    else:
+                        past_key_values[decoder_layer_index] = new_kv_cache
 
                 # Delete weight before moving to('meta')
                 for module in layer.modules():
@@ -296,7 +313,7 @@ class LowMemoryLlama(GenerationMixin):
 
         result = CausalLMOutputWithPast(
             logits=inputs.detach(),
-            past_key_values=kv_cache,
+            past_key_values=past_key_values,
         )
         return result
 
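
A note on the rewritten cache bookkeeping in the layer loop: on 4.36 every decoder layer receives the same shared Cache object, updates it in place, and returns that same object, so reassigning past_key_values = new_kv_cache on each iteration is safe. A hypothetical standalone illustration of that update call (dummy tensors):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = torch.zeros(1, 32, 4, 128)  # (batch, num_heads, seq_len, head_dim)
v = torch.zeros(1, 32, 4, 128)

# update() appends the new key/value states for one layer and returns
# the concatenated keys and values cached so far for that layer.
keys, values = cache.update(k, v, layer_idx=0)
print(cache.get_seq_length())  # 4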