Fix speculative decoding bug (#10855)
This commit is contained in:
		
							parent
							
								
									c9dee6cd0e
								
							
						
					
					
						commit
						92ea54b512
					
				
					 2 changed files with 8 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -43,7 +43,7 @@ if __name__ == '__main__':
 | 
			
		|||
                        help='Prompt to infer')
 | 
			
		||||
    parser.add_argument('--precision', type=str, default='bf16',
 | 
			
		||||
                        help='Main model Precision')
 | 
			
		||||
    parser.add_argument('--n_predict', type=int, default=128,
 | 
			
		||||
    parser.add_argument('--n-predict', type=int, default=128,
 | 
			
		||||
                        help='Max tokens to predict')
 | 
			
		||||
    parser.add_argument('--max-draft', type=int, default=8,
 | 
			
		||||
                        help='Max draft')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -74,8 +74,7 @@ def generate(
 | 
			
		|||
            for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
 | 
			
		||||
                        'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
 | 
			
		||||
                        'th_batch_num']:
 | 
			
		||||
                value = kwargs.pop(var, None)
 | 
			
		||||
            del self.draft_model
 | 
			
		||||
                kwargs.pop(var, None)
 | 
			
		||||
            return original_generate(self,
 | 
			
		||||
                                     inputs=inputs,
 | 
			
		||||
                                     generation_config=generation_config,
 | 
			
		||||
| 
						 | 
				
			
			@ -100,6 +99,12 @@ def generate(
 | 
			
		|||
                                         draft_model=self.draft_model,
 | 
			
		||||
                                         **new_speculative_kwargs)
 | 
			
		||||
    else:
 | 
			
		||||
        # When `draft_model` is false, these keyword arguments
 | 
			
		||||
        # related to speculative decoding should be removed
 | 
			
		||||
        for var in ['max_step_draft', 'th_stop_draft', 'hf_adjust',
 | 
			
		||||
                    'auto_th_stop_draft', 'auto_parameters', 'min_step_draft',
 | 
			
		||||
                    'th_batch_num']:
 | 
			
		||||
            kwargs.pop(var, None)
 | 
			
		||||
        return original_generate(self,
 | 
			
		||||
                                 inputs=inputs,
 | 
			
		||||
                                 generation_config=generation_config,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue