Update Llama2 multi-processes example (#11852)
* update llama2 multi-processes examples
* update
* update readme
* update
This commit is contained in:
parent 99b05ba1dc
commit 7380823f3f
2 changed files with 33 additions and 14 deletions
@@ -124,17 +124,27 @@ python llama2.py
 
 Arguments info:
 - `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Llama2 model (i.e. `meta-llama/Llama-2-7b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. The default is `'meta-llama/Llama-2-7b-chat-hf'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). The default is `What is AI?`.
 - `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. The default is `32`.
+- `--max-output-len MAX_OUTPUT_LEN`: argument defining the maximum sequence length for both input and output tokens. The default is `1024`.
+- `--max-prompt-len MAX_PROMPT_LEN`: argument defining the maximum number of tokens that the input prompt can contain. The default is `768`.
 
 #### Sample Output
 #### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
 
 ```log
 Inference time: xxxx s
--------------------- Prompt --------------------
-Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun
+-------------------- Input --------------------
+<s><s> [INST] <<SYS>>
+
+<</SYS>>
+
+What is AI? [/INST]
+
 -------------------- Output --------------------
-<s> Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun and exciting experiences.
-
-One day, she decided to go on a journey to find a magical land that was said to be full of wonders
+<s><s> [INST] <<SYS>>
+
+<</SYS>>
+
+What is AI? [/INST] AI (Artificial Intelligence) is a field of computer science and engineering that focuses on the development of intelligent machines that can perform tasks
 ```
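The two new length limits bound different things: `--max-prompt-len` caps the input tokens alone, while `--max-output-len` caps input and output together. A minimal sanity check of how the documented defaults fit together; the inequality is an inference from the argument descriptions above, not code from the example:

```python
# Documented defaults from the README arguments above.
max_output_len = 1024  # --max-output-len: max sequence length, input + output
max_prompt_len = 768   # --max-prompt-len: max tokens in the input prompt alone
n_predict = 32         # --n-predict: max number of tokens to predict

# Assumption: a prompt at the cap plus a full generation must still fit
# inside the overall sequence window.
assert max_prompt_len + n_predict <= max_output_len
print("length budget:", max_prompt_len, "+", n_predict, "<=", max_output_len)
```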
@@ -26,6 +26,18 @@ from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
+def get_prompt(message: str, chat_history: list[tuple[str, str]],
+               system_prompt: str) -> str:
+    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    # The first user input is _not_ stripped
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+    message = message.strip() if do_strip else message
+    texts.append(f'{message} [/INST]')
+    return ''.join(texts)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
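To see the string this helper builds, the snippet below re-runs `get_prompt` exactly as added above, with an empty chat history and an empty system prompt. Its output matches the `Input` section of the README sample; the doubled `<s><s>` there is, we assume, the tokenizer prepending its own BOS token on top of the literal `<s>` in the template:

```python
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    # Same helper as in the diff above.
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)

# No history, empty system prompt -- mirrors the README's sample Input:
# <s>[INST] <<SYS>>
#
# <</SYS>>
#
# What is AI? [/INST]
print(get_prompt("What is AI?", [], system_prompt=""))
```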
@@ -38,9 +50,11 @@ if __name__ == "__main__":
        help="The huggingface repo id for the Llama2 model to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
+    parser.add_argument('--prompt', type=str, default="What is AI?",
+                        help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-output-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=128)
+    parser.add_argument("--max-prompt-len", type=int, default=768)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--intra-pp", type=int, default=2)
     parser.add_argument("--inter-pp", type=int, default=2)
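A quick standalone check that the new flag and the changed default behave as the hunk intends; only the two arguments touched here are replayed, and the real parser defines more:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--prompt', type=str, default="What is AI?",
                    help='Prompt to infer')
parser.add_argument("--max-prompt-len", type=int, default=768)

args = parser.parse_args([])  # empty command line -> defaults
assert args.prompt == "What is AI?"
assert args.max_prompt_len == 768  # was 128 before this commit
```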
@@ -64,20 +78,15 @@ if __name__ == "__main__":
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
-    prompts = [
-        "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
-        "Once upon a time, there existed",
-        "Once upon a time, there existed a little girl who liked to have adventures.",
-    ]
+    DEFAULT_SYSTEM_PROMPT = """\
+"""
 
     print("-" * 80)
     print("done")
     with torch.inference_mode():
         print("finish to load")
         for i in range(5):
-            import random
-            idx = random.randint(0, 2)
-            prompt = prompts[idx]
+            prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
             _input_ids = tokenizer.encode(prompt, return_tensors="pt")
             print("input length:", len(_input_ids[0]))
             st = time.time()
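The hunk ends at `st = time.time()`; the `generate` call and the elapsed-time print that produce the README's `Inference time: xxxx s` line sit below this hunk and are untouched by the commit. A self-contained sketch of that timing pattern, with a hypothetical stub standing in for the loaded model:

```python
import time

class StubModel:
    """Hypothetical stand-in for the example's loaded model."""
    def generate(self, input_ids, max_new_tokens=32):
        time.sleep(0.01)  # placeholder for real decoding work
        return input_ids

model = StubModel()
_input_ids = [[1, 2, 3]]  # stand-in for tokenizer.encode(prompt, ...)

st = time.time()
output = model.generate(_input_ids, max_new_tokens=32)
print(f"Inference time: {time.time() - st} s")
```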