fix starcoder chinese output (#8773)

This commit is contained in:
Yishuo Wang 2023-08-18 13:37:02 +08:00 committed by GitHub
parent 548f7a6cf7
commit 2ba2133613
2 changed files with 19 additions and 8 deletions

View file

@ -227,8 +227,14 @@ class Starcoder(GenerationMixin):
match_str=match_str,
prompt=bytes(prompt, encoding='utf-8'),
buf=buf)
s = str(buf, encoding='utf-8').rstrip("\x00")
buf = buf.rstrip(b"\x00")
s = ''
for i in range(len(buf), 0, -1):
try:
s = buf[:i].decode("utf-8")
break
except UnicodeDecodeError as _e:
continue
text = s.split(prompt)[1]
split_text = text
if stop != []:
@ -289,10 +295,16 @@ class Starcoder(GenerationMixin):
}
}
else:
partial_tokens = []
for i in range(max_tokens):
token = self.forward(prompt_tokens)
prompt_tokens.append(token)
text = self.detokenize([token]).decode("utf-8", errors="ignore")
partial_tokens.append(token)
try:
text = self.detokenize(partial_tokens).decode("utf-8")
partial_tokens.clear()
except UnicodeDecodeError as _e:
continue
if text.endswith("<|endoftext|>"):
print('\n')
return
@ -346,10 +358,10 @@ class Starcoder(GenerationMixin):
"""
invalidInputError(self.ctx is not None,
"The attribute `ctx` of `Starcoder` object is None.")
output = ""
output = bytes()
for token in tokens:
output += starcoder_detokenize(self.ctx, token)
return output.encode('utf-8')
return output
def forward(self, input_ids: List[int]) -> int:
return starcoder_forward(ctx=self.ctx,

View file

@ -156,10 +156,9 @@ _lib.tokenize_api.restype = POINTER(c_int)
def starcoder_detokenize(ctx: c_void_p,
token_id: c_int) -> str:
token_id: c_int) -> bytes:
c_chars = _lib.detokenize_api(ctx, token_id)
s = c_chars.decode('utf-8')
return s
return c_chars
_lib.detokenize_api.argtypes = [c_void_p, c_int]