diff --git a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py index eacd808f..9e7af7a7 100644 --- a/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py +++ b/python/llm/src/bigdl/llm/ggml/model/bloom/bloom.py @@ -52,6 +52,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin from typing import List, Optional, Generator, Sequence, Union import time import uuid +import warnings class Bloom(GenerationMixin): @@ -130,8 +131,9 @@ class Bloom(GenerationMixin): 'last_n_tokens_size': 64, 'lora_base': None, 'lora_path': None, 'verbose': True} for arg in unsupported_arg.keys(): - invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}" - " is temporarily unsupported, please use the default value.") + if getattr(self, arg) != unsupported_arg[arg]: + warnings.warn(f"The parameter {arg} is temporarily unsupported, " + "please use the default value.") def __call__( self, @@ -173,8 +175,8 @@ class Bloom(GenerationMixin): 'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1} for index in range(len(args)): - invalidInputError(args[index] == defult_value[unsupported_arg[index]], - f"The parameter {unsupported_arg[index]} is temporarily " + if args[index] != defult_value[unsupported_arg[index]]: + warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily " "unsupported, please use the default value.") if stream: @@ -403,8 +405,8 @@ class Bloom(GenerationMixin): 'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, 'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1} for index in range(len(args)): - invalidInputError(args[index] == defult_value[unsupported_arg[index]], - f"The parameter {unsupported_arg[index]} is temporarily " + if args[index] != defult_value[unsupported_arg[index]]: + warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily " "unsupported, please use the default value.") invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.") diff --git a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py index b8d20a71..ff480b9a 100644 --- a/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py +++ b/python/llm/src/bigdl/llm/ggml/model/starcoder/starcoder.py @@ -53,6 +53,7 @@ from bigdl.llm.ggml.model.generation import GenerationMixin from typing import List, Optional, Generator, Sequence, Union import time import uuid +import warnings class Starcoder(GenerationMixin): @@ -131,8 +132,9 @@ class Starcoder(GenerationMixin): 'last_n_tokens_size': 64, 'lora_base': None, 'lora_path': None, 'verbose': True} for arg in unsupported_arg.keys(): - invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}" - " is temporarily unsupported, please use the default value.") + if getattr(self, arg) != unsupported_arg[arg]: + warnings.warn(f"The parameter {arg} is temporarily unsupported, " + "please use the default value.") def __call__( self, @@ -174,8 +176,8 @@ class Starcoder(GenerationMixin): 'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1} for index in range(len(args)): - invalidInputError(args[index] == defult_value[unsupported_arg[index]], - f"The parameter {unsupported_arg[index]} is temporarily " + if args[index] != defult_value[unsupported_arg[index]]: + warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily " "unsupported, please use the default value.") if stream: @@ -407,8 +409,8 @@ class Starcoder(GenerationMixin): 'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, 'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1} for index in range(len(args)): - invalidInputError(args[index] == defult_value[unsupported_arg[index]], - f"The parameter {unsupported_arg[index]} is temporarily " + if args[index] != defult_value[unsupported_arg[index]]: + warnings.warn(f"The parameter {unsupported_arg[index]} is temporarily " "unsupported, please use the default value.") invalidInputError(self.ctx is not None,