Update mpt for prompt tuning (#8547)

Yuwen Hu 2023-07-17 17:33:54 +08:00 committed by GitHub
parent f1fd746722
commit de772e7a80
2 changed files with 38 additions and 11 deletions

README.md

@@ -1,5 +1,5 @@
 # MPT
-In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on MPT models. For illustration purposes, we utilize the [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat) as a reference MPT model.
+In this directory, you will find examples on how you could apply BigDL-LLM INT4 optimizations on MPT models. For illustration purposes, we utilize the [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat) and [mosaicml/mpt-30b-chat](https://huggingface.co/mosaicml/mpt-30b-chat) as reference MPT models.
 ## 0. Requirements
 To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
@@ -13,7 +13,7 @@ conda create -n llm python=3.9
 conda activate llm
 pip install bigdl-llm[all] # install bigdl-llm with 'all' option
-pip install einops # additional package required for mpt-7b-chat to conduct generation
+pip install einops # additional package required for mpt-7b-chat and mpt-30b-chat to conduct generation
 ```
 ### 2. Run
@@ -22,7 +22,7 @@ python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROM
 ```
 Arguments info:
-- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the MPT model to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'mosaicml/mpt-7b-chat'`.
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the MPT model (e.g. `mosaicml/mpt-7b-chat` and `mosaicml/mpt-30b-chat`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'mosaicml/mpt-7b-chat'`.
 - `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
 - `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
@@ -54,7 +54,27 @@ numactl -C 0-47 -m 0 python ./generate.py
 ```log
 Inference time: xxxx s
 -------------------- Prompt --------------------
-<human>What is AI? <bot>
+<|im_start|>user
+What is AI?<|im_end|>
+<|im_start|>assistant
 -------------------- Output --------------------
-<human>What is AI? <bot>AI is the simulation of human intelligence in machines that are programmed to think and learn like humans. <human>What is machine learning? <bot>Machine learning
+user
+What is AI?
+assistant
+AI, or artificial intelligence, is the simulation of human intelligence in machines that are programmed to think and learn like humans. AI systems can perform tasks that typically require
 ```
+#### [mosaicml/mpt-30b-chat](https://huggingface.co/mosaicml/mpt-30b-chat)
+```log
+-------------------- Prompt --------------------
+<|im_start|>user
+What is AI?<|im_end|>
+<|im_start|>assistant
+-------------------- Output --------------------
+user
+What is AI?
+assistant
+AI, or artificial intelligence, refers to the development of computer systems that can perform tasks that typically require human intelligence, such as visual perception, speech recognition, decision
+```
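The sample logs above reflect the ChatML-style template this commit switches to. As a quick illustration of how that template expands for the example question, here is a minimal sketch that only reuses the `MPT_PROMPT_FORMAT` string added in this commit; no model or tokenizer is involved:

```python
# Minimal sketch: expand the new ChatML-style MPT prompt template for the
# sample question. Only the template string from this commit is used here.
MPT_PROMPT_FORMAT = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

prompt = MPT_PROMPT_FORMAT.format(prompt="What is AI?")
print(prompt)
# <|im_start|>user
# What is AI?<|im_end|>
# <|im_start|>assistant
```

The `Output` section of the logs shows plain `user` / `assistant` markers, presumably because `generate.py` decodes with `skip_special_tokens=True`, which drops tokens the tokenizer registers as special, such as `<|im_start|>` and `<|im_end|>`.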

generate.py

@@ -19,15 +19,17 @@ import time
 import argparse
 from bigdl.llm.transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, GenerationConfig
 
 # you could tune the prompt based on your own model,
-MPT_PROMPT_FORMAT = "<human>{prompt} <bot>"
+# here the prompt tuning refers to https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py
+MPT_PROMPT_FORMAT = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for MPT model')
     parser.add_argument('--repo-id-or-model-path', type=str, default="mosaicml/mpt-7b-chat",
-                        help='The huggingface repo id for the MPT to be downloaded'
+                        help='The huggingface repo id for the MPT models'
+                             ' (e.g. `mosaicml/mpt-7b-chat` and `mosaicml/mpt-30b-chat`) to be downloaded'
                              ', or the path to the huggingface checkpoint folder')
     parser.add_argument('--prompt', type=str, default="What is AI?",
                         help='Prompt to infer')
@@ -51,14 +53,19 @@ if __name__ == '__main__':
     with torch.inference_mode():
         prompt = MPT_PROMPT_FORMAT.format(prompt=args.prompt)
         input_ids = tokenizer.encode(prompt, return_tensors="pt")
-        st = time.time()
         # enabling `use_cache=True` allows the model to utilize the previous
         # key/values attentions to speed up decoding;
         # to obtain optimal performance with BigDL-LLM INT4 optimizations,
         # it is important to set use_cache=True for MPT models
+        mpt_generation_config = GenerationConfig(
+            max_new_tokens=args.n_predict,
+            use_cache=True,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id
+        )
+        st = time.time()
         output = model.generate(input_ids,
-                                use_cache=True,
-                                max_new_tokens=args.n_predict)
+                                generation_config=mpt_generation_config)
         end = time.time()
         output_str = tokenizer.decode(output[0], skip_special_tokens=True)
         print(f'Inference time: {end-st} s')
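Putting the two changes together, the updated generation path looks roughly like the sketch below. The model-loading call (`load_in_4bit=True`, `trust_remote_code=True`) is not part of this diff; it is an assumption based on the usual BigDL-LLM `transformers`-style loading API, so treat the sketch as illustrative rather than a copy of the full `generate.py`.

```python
# Condensed sketch of the updated flow; the model-loading arguments are
# assumptions (not shown in this diff) based on the usual BigDL-LLM API.
import time

import torch
from transformers import AutoTokenizer, GenerationConfig
from bigdl.llm.transformers import AutoModelForCausalLM

MPT_PROMPT_FORMAT = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

model_path = "mosaicml/mpt-7b-chat"  # or "mosaicml/mpt-30b-chat"
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,       # assumed: BigDL-LLM INT4 optimization
                                             trust_remote_code=True)  # assumed: needed for MPT remote code
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode(MPT_PROMPT_FORMAT.format(prompt="What is AI?"),
                                 return_tensors="pt")
    # use_cache=True lets MPT reuse past key/value attentions, which matters for
    # BigDL-LLM INT4 performance; eos/pad token ids keep generation well terminated.
    generation_config = GenerationConfig(max_new_tokens=32,
                                         use_cache=True,
                                         pad_token_id=tokenizer.eos_token_id,
                                         eos_token_id=tokenizer.eos_token_id)
    st = time.time()
    output = model.generate(input_ids, generation_config=generation_config)
    print(f'Inference time: {time.time() - st} s')
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Bundling the decoding options into a `GenerationConfig` keeps the `generate()` call itself minimal and makes it straightforward to reuse the same settings for both the 7B and 30B checkpoints.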