fix starcoder chinese output (#8773)
This commit is contained in:
parent
548f7a6cf7
commit
2ba2133613
2 changed files with 19 additions and 8 deletions
|
|
@@ -227,8 +227,14 @@ class Starcoder(GenerationMixin):
|
||||||
match_str=match_str,
|
match_str=match_str,
|
||||||
prompt=bytes(prompt, encoding='utf-8'),
|
prompt=bytes(prompt, encoding='utf-8'),
|
||||||
buf=buf)
|
buf=buf)
|
||||||
s = str(buf, encoding='utf-8').rstrip("\x00")
|
buf = buf.rstrip(b"\x00")
|
||||||
|
s = ''
|
||||||
|
for i in range(len(buf), 0, -1):
|
||||||
|
try:
|
||||||
|
s = buf[:i].decode("utf-8")
|
||||||
|
break
|
||||||
|
except UnicodeDecodeError as _e:
|
||||||
|
continue
|
||||||
text = s.split(prompt)[1]
|
text = s.split(prompt)[1]
|
||||||
split_text = text
|
split_text = text
|
||||||
if stop != []:
|
if stop != []:
|
||||||
|
|
@@ -289,10 +295,16 @@ class Starcoder(GenerationMixin):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
partial_tokens = []
|
||||||
for i in range(max_tokens):
|
for i in range(max_tokens):
|
||||||
token = self.forward(prompt_tokens)
|
token = self.forward(prompt_tokens)
|
||||||
prompt_tokens.append(token)
|
prompt_tokens.append(token)
|
||||||
text = self.detokenize([token]).decode("utf-8", errors="ignore")
|
partial_tokens.append(token)
|
||||||
|
try:
|
||||||
|
text = self.detokenize(partial_tokens).decode("utf-8")
|
||||||
|
partial_tokens.clear()
|
||||||
|
except UnicodeDecodeError as _e:
|
||||||
|
continue
|
||||||
if text.endswith("<|endoftext|>"):
|
if text.endswith("<|endoftext|>"):
|
||||||
print('\n')
|
print('\n')
|
||||||
return
|
return
|
||||||
|
|
@@ -346,10 +358,10 @@ class Starcoder(GenerationMixin):
|
||||||
"""
|
"""
|
||||||
invalidInputError(self.ctx is not None,
|
invalidInputError(self.ctx is not None,
|
||||||
"The attribute `ctx` of `Starcoder` object is None.")
|
"The attribute `ctx` of `Starcoder` object is None.")
|
||||||
output = ""
|
output = bytes()
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
output += starcoder_detokenize(self.ctx, token)
|
output += starcoder_detokenize(self.ctx, token)
|
||||||
return output.encode('utf-8')
|
return output
|
||||||
|
|
||||||
def forward(self, input_ids: List[int]) -> int:
|
def forward(self, input_ids: List[int]) -> int:
|
||||||
return starcoder_forward(ctx=self.ctx,
|
return starcoder_forward(ctx=self.ctx,
|
||||||
|
|
|
||||||
|
|
@@ -156,10 +156,9 @@ _lib.tokenize_api.restype = POINTER(c_int)
|
||||||
|
|
||||||
|
|
||||||
def starcoder_detokenize(ctx: c_void_p,
|
def starcoder_detokenize(ctx: c_void_p,
|
||||||
token_id: c_int) -> str:
|
token_id: c_int) -> bytes:
|
||||||
c_chars = _lib.detokenize_api(ctx, token_id)
|
c_chars = _lib.detokenize_api(ctx, token_id)
|
||||||
s = c_chars.decode('utf-8')
|
return c_chars
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
_lib.detokenize_api.argtypes = [c_void_p, c_int]
|
_lib.detokenize_api.argtypes = [c_void_p, c_int]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue