[WebUI] Add prompt format and stopping words for Qwen (#10066)
* add prompt format and stopping_words for qwen mdoel * performance optimization * optimize * update * meet comments
This commit is contained in:
parent
0aecd8637b
commit
4b02ff188b
2 changed files with 39 additions and 4 deletions
|
|
@ -41,6 +41,18 @@ class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
||||||
return shared.stop_everything
|
return shared.stop_everything
|
||||||
|
|
||||||
|
|
||||||
|
class StopWordsCriteria(transformers.StoppingCriteria):
|
||||||
|
"""Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
|
||||||
|
def __init__(self, stop_words, tokenizer):
|
||||||
|
self.stop_words = stop_words
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def __call__(self, input_ids, scores, **kwargs):
|
||||||
|
"""Returns true if all generated sequences contain any of the end-of-function strings."""
|
||||||
|
text = self.tokenizer.decode(input_ids[-1][-1])
|
||||||
|
return text in self.stop_words
|
||||||
|
|
||||||
|
|
||||||
class Stream(transformers.StoppingCriteria):
|
class Stream(transformers.StoppingCriteria):
|
||||||
def __init__(self, callback_func=None):
|
def __init__(self, callback_func=None):
|
||||||
self.callback_func = callback_func
|
self.callback_func = callback_func
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# This file is adapted from
|
# This file is adapted from
|
||||||
# https://github.com/oobabooga/text-generation-webui/blob/main/modules/text_generation.py
|
# https://github.com/oobabooga/text-generation-webui/blob/main/modules/text_generation.py
|
||||||
|
|
@ -35,7 +35,8 @@ import modules.shared as shared
|
||||||
from modules.callbacks import (
|
from modules.callbacks import (
|
||||||
Iteratorize,
|
Iteratorize,
|
||||||
Stream,
|
Stream,
|
||||||
_StopEverythingStoppingCriteria
|
_StopEverythingStoppingCriteria,
|
||||||
|
StopWordsCriteria
|
||||||
)
|
)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.grammar.grammar_utils import initialize_grammar
|
from modules.grammar.grammar_utils import initialize_grammar
|
||||||
|
|
@ -331,6 +332,19 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
generate_params.update({'synced_gpus': True})
|
generate_params.update({'synced_gpus': True})
|
||||||
|
|
||||||
|
#tune the prompt based on qwen
|
||||||
|
QWEN_PROMPT_FORMAT = """
|
||||||
|
<|im_start|>system
|
||||||
|
You are a helpful assistant.
|
||||||
|
<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
{prompt}
|
||||||
|
<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
"""
|
||||||
|
if shared.model.config.model_type == "qwen":
|
||||||
|
question = QWEN_PROMPT_FORMAT.format(prompt=question)
|
||||||
|
|
||||||
# Encode the input
|
# Encode the input
|
||||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
|
|
@ -346,10 +360,19 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
|
||||||
# Stopping criteria / eos token
|
# Stopping criteria / eos token
|
||||||
|
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
generate_params['eos_token_id'] = eos_token_ids
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
|
||||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
if shared.model.config.model_type == "qwen":
|
||||||
|
stopping_words = ["<|endoftext|>", "<|im_end|>", "<|im_start|>"]
|
||||||
|
generate_params['stopping_criteria'].append(StopWordsCriteria(stopping_words, shared.tokenizer))
|
||||||
|
|
||||||
|
for st in state['custom_stopping_strings']:
|
||||||
|
if type(st) is str:
|
||||||
|
stopping_words = [item.strip().strip('"') for item in [state['custom_stopping_strings']][0].split(',')]
|
||||||
|
generate_params['stopping_criteria'].append(StopWordsCriteria(stopping_words, shared.tokenizer))
|
||||||
|
|
||||||
|
|
||||||
# Logits processor
|
# Logits processor
|
||||||
processor = state.get('logits_processor', LogitsProcessorList([]))
|
processor = state.get('logits_processor', LogitsProcessorList([]))
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue