LLM: raise warning instead of error when use unsupported parameters (#8382)

This commit is contained in:
binbin Deng 2023-06-26 13:23:55 +08:00 committed by GitHub
parent 5ad5ac5356
commit 19e19efb4c
2 changed files with 16 additions and 12 deletions

View file

@ -52,6 +52,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
from typing import List, Optional, Generator, Sequence, Union
import time
import uuid
import warnings
class Bloom(GenerationMixin):
@ -130,8 +131,9 @@ class Bloom(GenerationMixin):
'last_n_tokens_size': 64, 'lora_base': None,
'lora_path': None, 'verbose': True}
for arg in unsupported_arg.keys():
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
" is temporarily unsupported, please use the default value.")
if getattr(self, arg) != unsupported_arg[arg]:
warnings.warn(f"The parameter {arg} is temporarily unsupported, "
"please use the default value.")
def __call__(
self,
@ -173,8 +175,8 @@ class Bloom(GenerationMixin):
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
f"The parameter {unsupported_arg[index]} is temporarily "
if args[index] != defult_value[unsupported_arg[index]]:
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.")
if stream:
@ -403,8 +405,8 @@ class Bloom(GenerationMixin):
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
f"The parameter {unsupported_arg[index]} is temporarily "
if args[index] != defult_value[unsupported_arg[index]]:
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.")
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")

View file

@ -53,6 +53,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
from typing import List, Optional, Generator, Sequence, Union
import time
import uuid
import warnings
class Starcoder(GenerationMixin):
@ -131,8 +132,9 @@ class Starcoder(GenerationMixin):
'last_n_tokens_size': 64, 'lora_base': None,
'lora_path': None, 'verbose': True}
for arg in unsupported_arg.keys():
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
" is temporarily unsupported, please use the default value.")
if getattr(self, arg) != unsupported_arg[arg]:
warnings.warn(f"The parameter {arg} is temporarily unsupported, "
"please use the default value.")
def __call__(
self,
@ -174,8 +176,8 @@ class Starcoder(GenerationMixin):
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
f"The parameter {unsupported_arg[index]} is temporarily "
if args[index] != defult_value[unsupported_arg[index]]:
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.")
if stream:
@ -407,8 +409,8 @@ class Starcoder(GenerationMixin):
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
f"The parameter {unsupported_arg[index]} is temporarily "
if args[index] != defult_value[unsupported_arg[index]]:
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.")
invalidInputError(self.ctx is not None,