LLM: raise warning instead of error when use unsupported parameters (#8382)
This commit is contained in:
parent
5ad5ac5356
commit
19e19efb4c
2 changed files with 16 additions and 12 deletions
|
|
@ -52,6 +52,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
|
|||
from typing import List, Optional, Generator, Sequence, Union
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
|
||||
class Bloom(GenerationMixin):
|
||||
|
|
@ -130,8 +131,9 @@ class Bloom(GenerationMixin):
|
|||
'last_n_tokens_size': 64, 'lora_base': None,
|
||||
'lora_path': None, 'verbose': True}
|
||||
for arg in unsupported_arg.keys():
|
||||
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
|
||||
" is temporarily unsupported, please use the default value.")
|
||||
if getattr(self, arg) != unsupported_arg[arg]:
|
||||
warnings.warn(f"The parameter {arg} is temporarily unsupported, "
|
||||
"please use the default value.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
|
@ -173,8 +175,8 @@ class Bloom(GenerationMixin):
|
|||
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
|
||||
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||
for index in range(len(args)):
|
||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
if args[index] != defult_value[unsupported_arg[index]]:
|
||||
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
"unsupported, please use the default value.")
|
||||
|
||||
if stream:
|
||||
|
|
@ -403,8 +405,8 @@ class Bloom(GenerationMixin):
|
|||
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
|
||||
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||
for index in range(len(args)):
|
||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
if args[index] != defult_value[unsupported_arg[index]]:
|
||||
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
"unsupported, please use the default value.")
|
||||
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
|
|||
from typing import List, Optional, Generator, Sequence, Union
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
|
||||
class Starcoder(GenerationMixin):
|
||||
|
|
@ -131,8 +132,9 @@ class Starcoder(GenerationMixin):
|
|||
'last_n_tokens_size': 64, 'lora_base': None,
|
||||
'lora_path': None, 'verbose': True}
|
||||
for arg in unsupported_arg.keys():
|
||||
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
|
||||
" is temporarily unsupported, please use the default value.")
|
||||
if getattr(self, arg) != unsupported_arg[arg]:
|
||||
warnings.warn(f"The parameter {arg} is temporarily unsupported, "
|
||||
"please use the default value.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
|
@ -174,8 +176,8 @@ class Starcoder(GenerationMixin):
|
|||
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
|
||||
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||
for index in range(len(args)):
|
||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
if args[index] != defult_value[unsupported_arg[index]]:
|
||||
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
"unsupported, please use the default value.")
|
||||
|
||||
if stream:
|
||||
|
|
@ -407,8 +409,8 @@ class Starcoder(GenerationMixin):
|
|||
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
|
||||
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||
for index in range(len(args)):
|
||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
if args[index] != defult_value[unsupported_arg[index]]:
|
||||
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||
"unsupported, please use the default value.")
|
||||
|
||||
invalidInputError(self.ctx is not None,
|
||||
|
|
|
|||
Loading…
Reference in a new issue