Fix garbled Chinese (multi-byte UTF-8) output in Starcoder (#8773)
This commit is contained in:
parent
548f7a6cf7
commit
2ba2133613
2 changed files with 19 additions and 8 deletions
|
|
@ -227,8 +227,14 @@ class Starcoder(GenerationMixin):
|
|||
match_str=match_str,
|
||||
prompt=bytes(prompt, encoding='utf-8'),
|
||||
buf=buf)
|
||||
s = str(buf, encoding='utf-8').rstrip("\x00")
|
||||
|
||||
buf = buf.rstrip(b"\x00")
|
||||
s = ''
|
||||
for i in range(len(buf), 0, -1):
|
||||
try:
|
||||
s = buf[:i].decode("utf-8")
|
||||
break
|
||||
except UnicodeDecodeError as _e:
|
||||
continue
|
||||
text = s.split(prompt)[1]
|
||||
split_text = text
|
||||
if stop != []:
|
||||
|
|
@ -289,10 +295,16 @@ class Starcoder(GenerationMixin):
|
|||
}
|
||||
}
|
||||
else:
|
||||
partial_tokens = []
|
||||
for i in range(max_tokens):
|
||||
token = self.forward(prompt_tokens)
|
||||
prompt_tokens.append(token)
|
||||
text = self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||
partial_tokens.append(token)
|
||||
try:
|
||||
text = self.detokenize(partial_tokens).decode("utf-8")
|
||||
partial_tokens.clear()
|
||||
except UnicodeDecodeError as _e:
|
||||
continue
|
||||
if text.endswith("<|endoftext|>"):
|
||||
print('\n')
|
||||
return
|
||||
|
|
@ -346,10 +358,10 @@ class Starcoder(GenerationMixin):
|
|||
"""
|
||||
invalidInputError(self.ctx is not None,
|
||||
"The attribute `ctx` of `Starcoder` object is None.")
|
||||
output = ""
|
||||
output = bytes()
|
||||
for token in tokens:
|
||||
output += starcoder_detokenize(self.ctx, token)
|
||||
return output.encode('utf-8')
|
||||
return output
|
||||
|
||||
def forward(self, input_ids: List[int]) -> int:
|
||||
return starcoder_forward(ctx=self.ctx,
|
||||
|
|
|
|||
|
|
@ -156,10 +156,9 @@ _lib.tokenize_api.restype = POINTER(c_int)
|
|||
|
||||
|
||||
def starcoder_detokenize(ctx: c_void_p,
|
||||
token_id: c_int) -> str:
|
||||
token_id: c_int) -> bytes:
|
||||
c_chars = _lib.detokenize_api(ctx, token_id)
|
||||
s = c_chars.decode('utf-8')
|
||||
return s
|
||||
return c_chars
|
||||
|
||||
|
||||
_lib.detokenize_api.argtypes = [c_void_p, c_int]
|
||||
|
|
|
|||
Loading…
Reference in a new issue