LLM: raise warning instead of error when use unsupported parameters (#8382)
This commit is contained in:
parent
5ad5ac5356
commit
19e19efb4c
2 changed files with 16 additions and 12 deletions
|
|
@ -52,6 +52,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
|
||||||
from typing import List, Optional, Generator, Sequence, Union
|
from typing import List, Optional, Generator, Sequence, Union
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class Bloom(GenerationMixin):
|
class Bloom(GenerationMixin):
|
||||||
|
|
@ -130,8 +131,9 @@ class Bloom(GenerationMixin):
|
||||||
'last_n_tokens_size': 64, 'lora_base': None,
|
'last_n_tokens_size': 64, 'lora_base': None,
|
||||||
'lora_path': None, 'verbose': True}
|
'lora_path': None, 'verbose': True}
|
||||||
for arg in unsupported_arg.keys():
|
for arg in unsupported_arg.keys():
|
||||||
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
|
if getattr(self, arg) != unsupported_arg[arg]:
|
||||||
" is temporarily unsupported, please use the default value.")
|
warnings.warn(f"The parameter {arg} is temporarily unsupported, "
|
||||||
|
"please use the default value.")
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -173,8 +175,8 @@ class Bloom(GenerationMixin):
|
||||||
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
|
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||||
for index in range(len(args)):
|
for index in range(len(args)):
|
||||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
if args[index] != defult_value[unsupported_arg[index]]:
|
||||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||||
"unsupported, please use the default value.")
|
"unsupported, please use the default value.")
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
|
@ -403,8 +405,8 @@ class Bloom(GenerationMixin):
|
||||||
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
|
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
|
||||||
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||||
for index in range(len(args)):
|
for index in range(len(args)):
|
||||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
if args[index] != defult_value[unsupported_arg[index]]:
|
||||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||||
"unsupported, please use the default value.")
|
"unsupported, please use the default value.")
|
||||||
|
|
||||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
|
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin
|
||||||
from typing import List, Optional, Generator, Sequence, Union
|
from typing import List, Optional, Generator, Sequence, Union
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class Starcoder(GenerationMixin):
|
class Starcoder(GenerationMixin):
|
||||||
|
|
@ -131,8 +132,9 @@ class Starcoder(GenerationMixin):
|
||||||
'last_n_tokens_size': 64, 'lora_base': None,
|
'last_n_tokens_size': 64, 'lora_base': None,
|
||||||
'lora_path': None, 'verbose': True}
|
'lora_path': None, 'verbose': True}
|
||||||
for arg in unsupported_arg.keys():
|
for arg in unsupported_arg.keys():
|
||||||
invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
|
if getattr(self, arg) != unsupported_arg[arg]:
|
||||||
" is temporarily unsupported, please use the default value.")
|
warnings.warn(f"The parameter {arg} is temporarily unsupported, "
|
||||||
|
"please use the default value.")
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -174,8 +176,8 @@ class Starcoder(GenerationMixin):
|
||||||
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
|
'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||||
for index in range(len(args)):
|
for index in range(len(args)):
|
||||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
if args[index] != defult_value[unsupported_arg[index]]:
|
||||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||||
"unsupported, please use the default value.")
|
"unsupported, please use the default value.")
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
|
@ -407,8 +409,8 @@ class Starcoder(GenerationMixin):
|
||||||
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
|
'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
|
||||||
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
|
||||||
for index in range(len(args)):
|
for index in range(len(args)):
|
||||||
invalidInputError(args[index] == defult_value[unsupported_arg[index]],
|
if args[index] != defult_value[unsupported_arg[index]]:
|
||||||
f"The parameter {unsupported_arg[index]} is temporarily "
|
warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily "
|
||||||
"unsupported, please use the default value.")
|
"unsupported, please use the default value.")
|
||||||
|
|
||||||
invalidInputError(self.ctx is not None,
|
invalidInputError(self.ctx is not None,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue