LLM: raise warning instead of error when use unsupported parameters (#8382)

This commit is contained in:
binbin Deng 2023-06-26 13:23:55 +08:00 committed by GitHub
parent 5ad5ac5356
commit 19e19efb4c
2 changed files with 16 additions and 12 deletions

View file

@ -52,6 +52,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
from typing import List, Optional, Generator, Sequence, Union from typing import List, Optional, Generator, Sequence, Union
import time import time
import uuid import uuid
import warnings
class Bloom(GenerationMixin): class Bloom(GenerationMixin):
@ -130,8 +131,9 @@ class Bloom(GenerationMixin):
'last_n_tokens_size': 64, 'lora_base': None, 'last_n_tokens_size': 64, 'lora_base': None,
'lora_path': None, 'verbose': True} 'lora_path': None, 'verbose': True}
for arg in unsupported_arg.keys(): for arg in unsupported_arg.keys():
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}" if getattr(self, arg) != unsupported_arg[arg]:
" is temporarily unsupported, please use the default value.") warnings.warn(f"The parameter {arg} is temporarily unsupported, "
"please use the default value.")
def __call__( def __call__(
self, self,
@ -173,8 +175,8 @@ class Bloom(GenerationMixin):
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0, 'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
'mirostat_tau': 5.0, 'mirostat_eta': 0.1} 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)): for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]], if args[index] != defult_value[unsupported_arg[index]]:
f"The parameter {unsupported_arg[index]} is temporarily " warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.") "unsupported, please use the default value.")
if stream: if stream:
@ -403,8 +405,8 @@ class Bloom(GenerationMixin):
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, 'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1} 'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)): for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]], if args[index] != defult_value[unsupported_arg[index]]:
f"The parameter {unsupported_arg[index]} is temporarily " warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.") "unsupported, please use the default value.")
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.") invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")

View file

@ -53,6 +53,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
from typing import List, Optional, Generator, Sequence, Union from typing import List, Optional, Generator, Sequence, Union
import time import time
import uuid import uuid
import warnings
class Starcoder(GenerationMixin): class Starcoder(GenerationMixin):
@ -131,8 +132,9 @@ class Starcoder(GenerationMixin):
'last_n_tokens_size': 64, 'lora_base': None, 'last_n_tokens_size': 64, 'lora_base': None,
'lora_path': None, 'verbose': True} 'lora_path': None, 'verbose': True}
for arg in unsupported_arg.keys(): for arg in unsupported_arg.keys():
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}" if getattr(self, arg) != unsupported_arg[arg]:
" is temporarily unsupported, please use the default value.") warnings.warn(f"The parameter {arg} is temporarily unsupported, "
"please use the default value.")
def __call__( def __call__(
self, self,
@ -174,8 +176,8 @@ class Starcoder(GenerationMixin):
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0, 'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
'mirostat_tau': 5.0, 'mirostat_eta': 0.1} 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)): for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]], if args[index] != defult_value[unsupported_arg[index]]:
f"The parameter {unsupported_arg[index]} is temporarily " warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.") "unsupported, please use the default value.")
if stream: if stream:
@ -407,8 +409,8 @@ class Starcoder(GenerationMixin):
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, 'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1} 'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
for index in range(len(args)): for index in range(len(args)):
invalidInputError(args[index] == defult_value[unsupported_arg[index]], if args[index] != defult_value[unsupported_arg[index]]:
f"The parameter {unsupported_arg[index]} is temporarily " warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
"unsupported, please use the default value.") "unsupported, please use the default value.")
invalidInputError(self.ctx is not None, invalidInputError(self.ctx is not None,