Fix StarCoder Chinese output (#8773)
This commit is contained in:
		
							parent
							
								
									548f7a6cf7
								
							
						
					
					
						commit
						2ba2133613
					
				
					 2 changed files with 19 additions and 8 deletions
				
			
		| 
						 | 
				
			
			@ -227,8 +227,14 @@ class Starcoder(GenerationMixin):
 | 
			
		|||
                            match_str=match_str,
 | 
			
		||||
                            prompt=bytes(prompt, encoding='utf-8'),
 | 
			
		||||
                            buf=buf)
 | 
			
		||||
        s = str(buf, encoding='utf-8').rstrip("\x00")
 | 
			
		||||
 | 
			
		||||
        buf = buf.rstrip(b"\x00")
 | 
			
		||||
        s = ''
 | 
			
		||||
        for i in range(len(buf), 0, -1):
 | 
			
		||||
            try:
 | 
			
		||||
                s = buf[:i].decode("utf-8")
 | 
			
		||||
                break
 | 
			
		||||
            except UnicodeDecodeError as _e:
 | 
			
		||||
                continue
 | 
			
		||||
        text = s.split(prompt)[1]
 | 
			
		||||
        split_text = text
 | 
			
		||||
        if stop != []:
 | 
			
		||||
| 
						 | 
				
			
			@ -289,10 +295,16 @@ class Starcoder(GenerationMixin):
 | 
			
		|||
                }
 | 
			
		||||
            }
 | 
			
		||||
        else:
 | 
			
		||||
            partial_tokens = []
 | 
			
		||||
            for i in range(max_tokens):
 | 
			
		||||
                token = self.forward(prompt_tokens)
 | 
			
		||||
                prompt_tokens.append(token)
 | 
			
		||||
                text = self.detokenize([token]).decode("utf-8", errors="ignore")
 | 
			
		||||
                partial_tokens.append(token)
 | 
			
		||||
                try:
 | 
			
		||||
                    text = self.detokenize(partial_tokens).decode("utf-8")
 | 
			
		||||
                    partial_tokens.clear()
 | 
			
		||||
                except UnicodeDecodeError as _e:
 | 
			
		||||
                    continue
 | 
			
		||||
                if text.endswith("<|endoftext|>"):
 | 
			
		||||
                    print('\n')
 | 
			
		||||
                    return
 | 
			
		||||
| 
						 | 
				
			
			@ -346,10 +358,10 @@ class Starcoder(GenerationMixin):
 | 
			
		|||
        """
 | 
			
		||||
        invalidInputError(self.ctx is not None,
 | 
			
		||||
                          "The attribute `ctx` of `Starcoder` object is None.")
 | 
			
		||||
        output = ""
 | 
			
		||||
        output = bytes()
 | 
			
		||||
        for token in tokens:
 | 
			
		||||
            output += starcoder_detokenize(self.ctx, token)
 | 
			
		||||
        return output.encode('utf-8')
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def forward(self, input_ids: List[int]) -> int:
 | 
			
		||||
        return starcoder_forward(ctx=self.ctx,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -156,10 +156,9 @@ _lib.tokenize_api.restype = POINTER(c_int)
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def starcoder_detokenize(ctx: c_void_p,
 | 
			
		||||
                         token_id: c_int) -> str:
 | 
			
		||||
                         token_id: c_int) -> bytes:
 | 
			
		||||
    c_chars = _lib.detokenize_api(ctx, token_id)
 | 
			
		||||
    s = c_chars.decode('utf-8')
 | 
			
		||||
    return s
 | 
			
		||||
    return c_chars
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_lib.detokenize_api.argtypes = [c_void_p, c_int]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue