diff --git a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py
index 3432a180..2a26d94e 100644
--- a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py
+++ b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py
@@ -227,8 +227,14 @@ class Starcoder(GenerationMixin):
                                 match_str=match_str,
                                 prompt=bytes(prompt, encoding='utf-8'),
                                 buf=buf)
-            s = str(buf, encoding='utf-8').rstrip("\x00")
-
+            buf = buf.rstrip(b"\x00")
+            s = ''
+            for i in range(len(buf), 0, -1):
+                try:
+                    s = buf[:i].decode("utf-8")
+                    break
+                except UnicodeDecodeError as _e:
+                    continue
             text = s.split(prompt)[1]
             split_text = text
             if stop != []:
@@ -289,10 +295,16 @@ class Starcoder(GenerationMixin):
                 }
             }
         else:
+            partial_tokens = []
             for i in range(max_tokens):
                 token = self.forward(prompt_tokens)
                 prompt_tokens.append(token)
-                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                partial_tokens.append(token)
+                try:
+                    text = self.detokenize(partial_tokens).decode("utf-8")
+                    partial_tokens.clear()
+                except UnicodeDecodeError as _e:
+                    continue
                 if text.endswith("<|endoftext|>"):
                     print('\n')
                     return
@@ -346,10 +358,10 @@ class Starcoder(GenerationMixin):
         """
         invalidInputError(self.ctx is not None,
                           "The attribute `ctx` of `Starcoder` object is None.")
-        output = ""
+        output = bytes()
         for token in tokens:
             output += starcoder_detokenize(self.ctx, token)
-        return output.encode('utf-8')
+        return output
 
     def forward(self, input_ids: List[int]) -> int:
         return starcoder_forward(ctx=self.ctx,
diff --git a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py
index d2cb9631..2b0d80ee 100644
--- a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py
+++ b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder_cpp.py
@@ -156,10 +156,9 @@ _lib.tokenize_api.restype = POINTER(c_int)
 
 
 def starcoder_detokenize(ctx: c_void_p,
-                         token_id: c_int) -> str:
+                         token_id: c_int) -> bytes:
     c_chars = _lib.detokenize_api(ctx, token_id)
-    s = c_chars.decode('utf-8')
-    return s
+    return c_chars
 
 
 _lib.detokenize_api.argtypes = [c_void_p, c_int]
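
Note (illustrative sketch, not part of the patch): the streaming branch above buffers token ids and only emits text once the accumulated detokenized bytes decode as valid UTF-8, because a single StarCoder token can end in the middle of a multi-byte character (e.g. Chinese). The standalone snippet below shows the same pattern in isolation; `stream_decode`, `detokenize_bytes`, and the toy vocabulary are hypothetical stand-ins for `Starcoder.detokenize`, which after this change returns raw bytes instead of str.

def stream_decode(token_ids, detokenize_bytes):
    """Yield decoded text pieces, holding back tokens whose bytes stop
    in the middle of a multi-byte UTF-8 sequence."""
    pending = []                                  # tokens whose bytes do not decode yet
    for token in token_ids:
        pending.append(token)
        raw = b"".join(detokenize_bytes(t) for t in pending)
        try:
            piece = raw.decode("utf-8")           # succeeds only on complete characters
        except UnicodeDecodeError:
            continue                              # wait for the next token's bytes
        pending.clear()
        yield piece

# Toy usage: the three-byte character "中" is split across two "tokens".
fake_vocab = {0: b"hi ", 1: b"\xe4\xb8", 2: b"\xad", 3: b"!"}
print("".join(stream_decode([0, 1, 2, 3], fake_vocab.__getitem__)))   # -> hi 中!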