LLM: fix ChatGLM2 native int4 stream output (#8733)
parent ca3e59a1dc
commit 77efcf7b1d

1 changed file with 7 additions and 4 deletions
@@ -220,7 +220,7 @@ class ChatGLM(GenerationMixin):
             }
 
         n_past = 0
-        output_tokens = []
+        output_tokens = input_tokens
         for i in range(max_tokens):
             token = self.forward(input_ids=input_tokens,
                                  n_past=n_past,
@@ -234,7 +234,7 @@ class ChatGLM(GenerationMixin):
                 break
 
         text = self.detokenize(output_tokens)
-        split_text = text
+        split_text = text[len(prompt):]
         if stop != []:
             for stop_word in stop:
                 split_text = split_text.split(stop_word)[0]
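Note on these two hunks: output_tokens now starts from input_tokens instead of an empty list, so self.detokenize sees the prompt and the completion together, and text[len(prompt):] strips the echoed prompt back off. Detokenizing only the generated tokens can mangle text at the prompt/completion boundary, because subword detokenizers are not simple per-token concatenations. A minimal, self-contained sketch of the idea (the toy detokenize below is hypothetical, standing in for ChatGLM's real one):

# Hypothetical SentencePiece-style detokenizer: "\u2581" marks a word
# boundary, and the leading space of the whole text is stripped.
def detokenize(tokens):
    return "".join(t.replace("\u2581", " ") for t in tokens).lstrip()

prompt_tokens = ["\u2581Hi"]
generated = ["\u2581there"]

# Old behaviour: detokenize the completion alone -> boundary space is lost.
print(detokenize(generated))                  # 'there'

# Fixed behaviour: detokenize prompt + completion, then slice the prompt off.
prompt = detokenize(prompt_tokens)            # 'Hi'
text = detokenize(prompt_tokens + generated)  # 'Hi there'
print(text[len(prompt):])                     # ' there'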
@@ -294,7 +294,8 @@ class ChatGLM(GenerationMixin):
             }
         else:
             n_past = 0
-            output_tokens = []
+            output_tokens = input_tokens
+            history_text = prompt
             for i in range(max_tokens):
                 token = self.forward(input_ids=input_tokens,
                                      n_past=n_past,
@@ -307,7 +308,9 @@ class ChatGLM(GenerationMixin):
                 if token == self.eos_token():
                     print('\n')
                     break
-                text = self.detokenize(token)
+                text = self.detokenize(output_tokens)
+                text = text[len(history_text):]
+                history_text += text
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
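Note on the streaming hunks: the old code detokenized each token in isolation (self.detokenize(token)), which breaks whenever a piece of output text spans multiple tokens. The fix re-detokenizes the full output_tokens on every step and yields only the suffix beyond history_text, which starts out as the prompt. A runnable sketch of that delta pattern, under the same hypothetical detokenizer as above:

def detokenize(tokens):
    # Hypothetical stand-in for the model's detokenizer (see the note above).
    return "".join(t.replace("\u2581", " ") for t in tokens).lstrip()

def stream(prompt_tokens, generated):
    output_tokens = list(prompt_tokens)       # seeded with the prompt tokens
    history_text = detokenize(output_tokens)  # equals the prompt text
    for token in generated:
        output_tokens.append(token)
        text = detokenize(output_tokens)      # full-sequence detokenization
        delta = text[len(history_text):]      # only the newly produced suffix
        history_text += delta
        yield delta

chunks = list(stream(["\u2581Hi"], ["\u2581the", "re", "!"]))
print(chunks)           # [' the', 're', '!']
print("".join(chunks))  # ' there!'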